Respair committed on
Commit ffaf0d2 · verified · 1 Parent(s): ef029bd

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. LICENSE +51 -0
  2. __pycache__/discriminator.cpython-312.pyc +0 -0
  3. __pycache__/higgs_audio_tokenizer.cpython-311.pyc +0 -0
  4. __pycache__/higgs_audio_tokenizer.cpython-312.pyc +0 -0
  5. __pycache__/loss.cpython-312.pyc +0 -0
  6. __pycache__/semantic_module.cpython-312.pyc +0 -0
  7. boson_codeit.py +651 -0
  8. descriptaudiocodec/__init__.py +0 -0
  9. descriptaudiocodec/__pycache__/__init__.cpython-311.pyc +0 -0
  10. descriptaudiocodec/__pycache__/__init__.cpython-312.pyc +0 -0
  11. descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc +0 -0
  12. descriptaudiocodec/dac/model/__pycache__/base.cpython-312.pyc +0 -0
  13. descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc +0 -0
  14. descriptaudiocodec/dac/model/__pycache__/dac.cpython-312.pyc +0 -0
  15. descriptaudiocodec/dac/model/base.py +286 -0
  16. descriptaudiocodec/dac/model/dac.py +365 -0
  17. descriptaudiocodec/dac/nn/layers.py +33 -0
  18. descriptaudiocodec/dac/nn/quantize.py +251 -0
  19. discriminator.py +596 -0
  20. higgs_audio_tokenizer.py +373 -0
  21. loss.py +368 -0
  22. outputs/logs/250801-104649/events.out.tfevents.1754045209.192-222-50-191.575849.0 +3 -0
  23. outputs/logs/250801-104824/events.out.tfevents.1754045304.192-222-50-191.577752.0 +3 -0
  24. outputs/logs/250801-104944/events.out.tfevents.1754045384.192-222-50-191.579650.0 +3 -0
  25. outputs/logs/250801-105034/events.out.tfevents.1754045434.192-222-50-191.581483.0 +3 -0
  26. outputs/logs/250801-105133/events.out.tfevents.1754045493.192-222-50-191.583409.0 +3 -0
  27. outputs/logs/250801-134657/events.out.tfevents.1754056017.192-222-50-191.688744.0 +3 -0
  28. outputs/logs/250801-135301/events.out.tfevents.1754056381.192-222-50-191.693590.0 +3 -0
  29. outputs/logs/250801-135344/events.out.tfevents.1754056424.192-222-50-191.695388.0 +3 -0
  30. outputs/logs/250801-135510/events.out.tfevents.1754056510.192-222-50-191.697490.0 +3 -0
  31. outputs/logs/250801-202235/events.out.tfevents.1754079755.192-222-50-191.6026.0 +3 -0
  32. outputs/logs/250801-202320/events.out.tfevents.1754079800.192-222-50-191.6708.0 +3 -0
  33. outputs/logs/250802-065733/events.out.tfevents.1754117853.192-222-50-191.86944.0 +3 -0
  34. outputs/logs/250802-072035/events.out.tfevents.1754119235.192-222-50-191.100690.0 +3 -0
  35. outputs_24/logs/250730-112649/events.out.tfevents.1753874809.192-222-50-191.3556345.0 +3 -0
  36. outputs_24/logs/250730-112910/events.out.tfevents.1753874950.192-222-50-191.3557426.0 +3 -0
  37. outputs_24/logs/250730-113135/events.out.tfevents.1753875095.192-222-50-191.3558918.0 +3 -0
  38. outputs_24/logs/250730-114727/events.out.tfevents.1753876047.192-222-50-191.3567432.0 +3 -0
  39. outputs_24/logs/250730-115006/events.out.tfevents.1753876206.192-222-50-191.3569242.0 +3 -0
  40. outputs_24/logs/250730-151325/events.out.tfevents.1753888405.192-222-50-191.3660307.0 +3 -0
  41. outputs_24/logs/250730-152054/events.out.tfevents.1753888854.192-222-50-191.3663830.0 +3 -0
  42. outputs_24/logs/250730-152132/events.out.tfevents.1753888892.192-222-50-191.3664702.0 +3 -0
  43. outputs_24/logs/250730-152218/events.out.tfevents.1753888938.192-222-50-191.3665630.0 +3 -0
  44. outputs_24/logs/250730-152329/events.out.tfevents.1753889009.192-222-50-191.3666743.0 +3 -0
  45. outputs_24/logs/250730-152554/events.out.tfevents.1753889154.192-222-50-191.3668339.0 +3 -0
  46. outputs_24/logs/250730-152702/events.out.tfevents.1753889222.192-222-50-191.3669391.0 +3 -0
  47. outputs_24/logs/250730-152902/events.out.tfevents.1753889342.192-222-50-191.3671654.0 +3 -0
  48. outputs_24/logs/250730-161025/events.out.tfevents.1753891825.192-222-50-191.3698156.0 +3 -0
  49. outputs_24/logs/250730-165034/events.out.tfevents.1753894234.192-222-50-191.3717308.0 +3 -0
  50. outputs_24/logs/250730-165327/events.out.tfevents.1753894407.192-222-50-191.3719515.0 +3 -0
LICENSE ADDED
@@ -0,0 +1,51 @@
+ Third-Party License Attribution for Audio Processing Module
+ ===========================================================
+
+ This directory contains code derived from multiple open-source projects.
+ The following sections detail the licenses and attributions for third-party code.
+
+ ## XCodec Repository
+ The code in this directory is derived from:
+ https://github.com/zhenye234/xcodec
+
+ ## Individual File Attributions
+
+ ### Quantization Module (quantization/)
+ - Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
+ - Individual files contain their own license headers where applicable
+ - The vector-quantize-pytorch portions are licensed under the MIT License
+
+ ## License Terms
+
+ ### MIT License (for applicable portions)
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ## Attribution Requirements
+ When using this code, please ensure proper attribution to:
+ 1. The original xcodec repository: https://github.com/zhenye234/xcodec
+ 2. Any other repositories mentioned in individual file headers
+ 3. This derivative work and its modifications
+
+ ## Disclaimer
+ This directory contains modified versions of the original code. Please refer to
+ the original repositories for the canonical implementations and their specific
+ license terms.
+
+ For any questions about licensing or attribution, please check the individual
+ file headers and the original source repositories.
__pycache__/discriminator.cpython-312.pyc ADDED
Binary file (18.8 kB).
 
__pycache__/higgs_audio_tokenizer.cpython-311.pyc ADDED
Binary file (21.2 kB).
 
__pycache__/higgs_audio_tokenizer.cpython-312.pyc ADDED
Binary file (19.8 kB).
 
__pycache__/loss.cpython-312.pyc ADDED
Binary file (16.4 kB).
 
__pycache__/semantic_module.cpython-312.pyc ADDED
Binary file (10.6 kB).
 
boson_codeit.py ADDED
@@ -0,0 +1,651 @@
1
+ # #!/usr/bin/env python3
2
+ # """
3
+ # Audio Processing Script for Boson Codes
4
+ # Processes audio files in parallel using Higgs Audio Tokenizer
5
+ # and saves encoded representations as .pt files.
6
+ # """
7
+
8
+ # import os
9
+ # import sys
10
+ # import json
11
+ # import torch
12
+ # import librosa
13
+ # import numpy as np
14
+ # import warnings
15
+ # import argparse
16
+ # from pathlib import Path
17
+ # from multiprocessing import Pool
18
+ # from tqdm import tqdm
19
+
20
+ # from datasets import load_from_disk
21
+ # from higgs_audio_tokenizer import HiggsAudioTokenizer
22
+
23
+ # # Suppress PyTorch FutureWarnings
24
+ # warnings.filterwarnings("ignore", category=FutureWarning)
25
+
26
+ # # Global configuration
27
+ # DEFAULT_OUTPUT_DIR = "/home/ubuntu/boson_codes"
28
+ # DEFAULT_NUM_CORES = 48
29
+ # DEFAULT_SAMPLE_RATE = 44100
30
+ # DEFAULT_DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/data"
31
+
32
+ # # Model paths
33
+ # CONFIG_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/config.json"
34
+ # MODEL_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/model.pth"
35
+
36
+ # # Global model variable (initialized in each worker)
37
+ # model = None
38
+
39
+
40
+ # def init_worker():
41
+ # """Initialize model once per worker process."""
42
+ # global model
43
+ # device = 'cpu'
44
+
45
+ # # Load config
46
+ # with open(CONFIG_PATH, 'r') as f:
47
+ # config = json.load(f)
48
+
49
+ # # Initialize model
50
+ # model = HiggsAudioTokenizer(
51
+ # **config,
52
+ # device=device,
53
+ # )
54
+
55
+ # # Load weights
56
+ # parameter_dict = torch.load(MODEL_PATH, map_location=device)
57
+ # _ = model.load_state_dict(parameter_dict, strict=False)
58
+ # model = model.to(device)
59
+ # _ = model.eval()
60
+
61
+ # print(f"Model loaded in worker {os.getpid()}")
62
+
63
+
64
+ # def process_audio_file(args):
65
+ # """Process a single audio file using pre-loaded model."""
66
+ # filename, output_dir, sample_rate = args
67
+
68
+ # try:
69
+ # # Output filename - same name, just change extension to .pt
70
+ # base_name = Path(filename).stem
71
+ # output_path = os.path.join(output_dir, f"{base_name}.pt")
72
+
73
+ # # Skip if exists (double-check in case of race conditions)
74
+ # if os.path.exists(output_path):
75
+ # return ("skipped", filename)
76
+
77
+ # # Load and process audio
78
+ # wav, sr = librosa.load(filename, sr=sample_rate)
79
+ # wav = torch.from_numpy(wav).unsqueeze(0).float().to('cpu')
80
+
81
+ # # Encode using the pre-loaded model
82
+ # with torch.no_grad():
83
+ # encoded = model._xcodec_encode(wav.unsqueeze(0))
84
+
85
+ # # Save codes only
86
+ # torch.save(encoded.audio_codes, output_path)
87
+
88
+ # return ("success", filename)
89
+
90
+ # except Exception as e:
91
+ # return ("error", filename, str(e))
92
+
93
+
94
+ # def load_dataset(dataset_path):
95
+ # """Load and prepare the dataset."""
96
+ # print(f"Loading dataset from: {dataset_path}")
97
+ # ds = load_from_disk(dataset_path)
98
+ # print(f"Dataset info: {ds}")
99
+
100
+ # # Remove unnecessary columns
101
+ # columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask']
102
+ # existing_columns = [col for col in columns_to_remove if col in ds.column_names]
103
+ # if existing_columns:
104
+ # ds = ds.remove_columns(existing_columns)
105
+ # print(f"Removed columns: {existing_columns}")
106
+
107
+ # # Convert to pandas DataFrame
108
+ # df = ds.to_pandas()
109
+ # print(f"Loaded {len(df)} files from dataset")
110
+ # return df
111
+
112
+
113
+ # def main(args):
114
+ # """Main processing function."""
115
+ # # Change to audio processing directory
116
+ # os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing")
117
+ # print(f"Working directory: {os.getcwd()}")
118
+
119
+ # # Create output directory
120
+ # os.makedirs(args.output_dir, exist_ok=True)
121
+ # print(f"Output directory: {args.output_dir}")
122
+
123
+ # # Check if model files exist
124
+ # if not os.path.exists(CONFIG_PATH):
125
+ # print(f"Error: Config file not found at {CONFIG_PATH}")
126
+ # sys.exit(1)
127
+ # if not os.path.exists(MODEL_PATH):
128
+ # print(f"Error: Model file not found at {MODEL_PATH}")
129
+ # sys.exit(1)
130
+
131
+ # # Load dataset
132
+ # df = load_dataset(args.dataset_path)
133
+
134
+ # # Get filenames from dataframe
135
+ # all_filenames = df['filename'].tolist()
136
+
137
+ # # Pre-filter to exclude already processed files
138
+ # filenames_to_process = []
139
+ # already_processed = []
140
+
141
+ # print(f"\nChecking for already processed files...")
142
+ # for filename in all_filenames:
143
+ # base_name = Path(filename).stem
144
+ # output_path = os.path.join(args.output_dir, f"{base_name}.pt")
145
+ # if os.path.exists(output_path):
146
+ # already_processed.append(filename)
147
+ # else:
148
+ # filenames_to_process.append(filename)
149
+
150
+ # print(f"\nTotal files: {len(all_filenames)}")
151
+ # print(f"Already processed: {len(already_processed)}")
152
+ # print(f"To process: {len(filenames_to_process)}")
153
+
154
+ # if len(filenames_to_process) == 0:
155
+ # print("\nAll files have already been processed!")
156
+ # return
157
+
158
+ # print(f"\nProcessing {len(filenames_to_process)} files using {args.num_cores} cores...")
159
+ # print(f"Sample rate: {args.sample_rate} Hz")
160
+
161
+ # # Prepare arguments for multiprocessing
162
+ # process_args = [(filename, args.output_dir, args.sample_rate)
163
+ # for filename in filenames_to_process]
164
+
165
+ # # Process in parallel with model reuse
166
+ # with Pool(processes=args.num_cores, initializer=init_worker) as pool:
167
+ # results = list(tqdm(
168
+ # pool.imap(process_audio_file, process_args, chunksize=args.chunksize),
169
+ # total=len(filenames_to_process),
170
+ # desc="Processing audio files"
171
+ # ))
172
+
173
+ # # Count results
174
+ # processed = sum(1 for r in results if r[0] == "success")
175
+ # skipped = sum(1 for r in results if r[0] == "skipped")
176
+ # errors = sum(1 for r in results if r[0] == "error")
177
+
178
+ # print(f"\nProcessing complete!")
179
+ # print(f" Successfully processed: {processed}")
180
+ # print(f" Previously processed: {len(already_processed)}")
181
+ # print(f" Skipped (race condition): {skipped}")
182
+ # print(f" Errors: {errors}")
183
+
184
+ # # Show errors if any
185
+ # if errors > 0:
186
+ # print("\nErrors encountered:")
187
+ # error_log_path = os.path.join(args.output_dir, "processing_errors.log")
188
+ # with open(error_log_path, 'w') as f:
189
+ # for r in results:
190
+ # if r[0] == "error":
191
+ # error_msg = f"{r[1]}: {r[2]}"
192
+ # print(f" {error_msg}")
193
+ # f.write(error_msg + "\n")
194
+ # print(f"\nError log saved to: {error_log_path}")
195
+
196
+ # # Show summary of all processed files
197
+ # total_processed_files = len(list(Path(args.output_dir).glob("*.pt")))
198
+ # print(f"\nTotal .pt files in {args.output_dir}: {total_processed_files}")
199
+
200
+
201
+ # if __name__ == "__main__":
202
+ # parser = argparse.ArgumentParser(
203
+ # description="Process audio files using Higgs Audio Tokenizer and save as .pt files"
204
+ # )
205
+
206
+ # parser.add_argument(
207
+ # "--dataset-path",
208
+ # type=str,
209
+ # default=DEFAULT_DATASET_PATH,
210
+ # help=f"Path to the dataset (default: {DEFAULT_DATASET_PATH})"
211
+ # )
212
+
213
+ # parser.add_argument(
214
+ # "--output-dir",
215
+ # type=str,
216
+ # default=DEFAULT_OUTPUT_DIR,
217
+ # help=f"Output directory for .pt files (default: {DEFAULT_OUTPUT_DIR})"
218
+ # )
219
+
220
+ # parser.add_argument(
221
+ # "--num-cores",
222
+ # type=int,
223
+ # default=DEFAULT_NUM_CORES,
224
+ # help=f"Number of CPU cores to use (default: {DEFAULT_NUM_CORES})"
225
+ # )
226
+
227
+ # parser.add_argument(
228
+ # "--sample-rate",
229
+ # type=int,
230
+ # default=DEFAULT_SAMPLE_RATE,
231
+ # help=f"Sample rate for audio processing (default: {DEFAULT_SAMPLE_RATE})"
232
+ # )
233
+
234
+ # parser.add_argument(
235
+ # "--chunksize",
236
+ # type=int,
237
+ # default=1,
238
+ # help="Chunksize for multiprocessing pool (default: 1)"
239
+ # )
240
+
241
+ # args = parser.parse_args()
242
+
243
+ # # Run main processing
244
+ # try:
245
+ # main(args)
246
+ # except KeyboardInterrupt:
247
+ # print("\n\nProcessing interrupted by user")
248
+ # sys.exit(1)
249
+ # except Exception as e:
250
+ # print(f"\n\nError: {e}")
251
+ # sys.exit(1)
252
+
253
+ #!/usr/bin/env python3
254
+ """
255
+ GPU Batch Processing Script for Boson Codes with Dataset Loading
256
+ """
257
+
258
+ import os
259
+ import sys
260
+ import json
261
+ import torch
262
+ import torch.nn.functional as F
263
+ import librosa
264
+ import numpy as np
265
+ from pathlib import Path
266
+ from tqdm import tqdm
267
+ import warnings
268
+ from torch.nn.utils import remove_weight_norm, weight_norm
269
+
270
+
271
+ # from boson_multimodal.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
272
+ # model = load_higgs_audio_tokenizer("bosonai/higgs-audio-v2-tokenizer")
273
+ import librosa
274
+ import torch
275
+ import torch.nn.functional as F
276
+ import numpy as np
277
+ import json
278
+ import torch
279
+
280
+ from higgs_audio_tokenizer import HiggsAudioTokenizer
281
+ # model = load_higgs_audio_tokenizer("bosonai/higgs-audio-v2-tokenizer")
282
+
283
+ import torch
284
+ import torch.nn as nn
285
+ import warnings
286
+
287
+ # Suppress warnings
288
+ warnings.filterwarnings('ignore')
289
+
290
+ def remove_weight_norms_from_model(model):
291
+ for module in model.modules():
292
+ try:
293
+ remove_weight_norm(module)
294
+ except ValueError:  # remove_weight_norm raises ValueError when the module has no weight norm
295
+ continue
296
+ return model
297
+
298
+
299
+ class EncodedResult:
300
+ def __init__(self, audio_codes):
301
+ self.audio_codes = audio_codes
302
+
303
+ def encode_batch(model, x_batch):
304
+ """
305
+ Encodes a batch of audio tensors using the HiggsAudioTokenizer model.
306
+ Args:
307
+ model: The loaded HiggsAudioTokenizer model.
308
+ x_batch: A tensor of shape [B, 1, T]
309
+ """
310
+ # Acoustic and Semantic Feature Extraction
311
+ e_semantic_input = model.get_regress_target(x_batch).detach()
312
+ e_semantic = model.encoder_semantic(e_semantic_input.transpose(1, 2))
313
+ e_acoustic = model.encoder(x_batch)
314
+
315
+ # This block contains the fix for batch processing
316
+ if e_acoustic.shape[2] != e_semantic.shape[2]:
317
+ pad_size = 160 * model.semantic_downsample_factor
318
+
319
+ # 1. Remove channel dim, preserving batch dim -> [B, T]
320
+ x_slice = x_batch[:, 0, :]
321
+
322
+ # 2. Pad the tensor
323
+ x_padded = F.pad(x_slice, (pad_size, pad_size))
324
+
325
+ # 3. Re-add channel dim before passing to encoder -> [B, 1, T_padded]
326
+ e_acoustic = model.encoder(x_padded.unsqueeze(1))
327
+
328
+ # Ensure dimensions match before concatenating
329
+ min_len = min(e_acoustic.shape[2], e_semantic.shape[2])
330
+ e_acoustic = e_acoustic[:, :, :min_len]
331
+ e_semantic = e_semantic[:, :, :min_len]
332
+
333
+ # Remainder of the original encoding logic
334
+ e = torch.cat([e_acoustic, e_semantic], dim=1)
335
+ e = model.fc_prior(e.transpose(1, 2))
336
+
337
+ if model.quantizer_type == "RVQ":
338
+ e = e.transpose(1, 2)
339
+ _, codes, _, _ = model.quantizer(e, model.frame_rate, None)
340
+ codes = codes.permute(1, 0, 2)
341
+ else: # RFSQ
342
+ quantized, codes = model.quantizer(e)
343
+ codes = codes.permute(0, 2, 1)
344
+
345
+ return EncodedResult(audio_codes=codes)
346
+
347
+
348
+ def fix_all_inference_issues(model):
349
+ """
350
+ Comprehensive fix for all potential inference issues
351
+ """
352
+ device = next(model.parameters()).device
353
+
354
+ # 1. Force everything to eval mode
355
+ model.eval()
356
+ with torch.no_grad():
357
+ for module in model.modules():
358
+ if isinstance(module, nn.Module):
359
+ module.eval()
360
+ if hasattr(module, 'training'):
361
+ module.training = False
362
+
363
+ # 2. Fix semantic model specifically
364
+ if hasattr(model, 'semantic_model'):
365
+ print("Fixing semantic model...")
366
+
367
+ # Move to correct device
368
+ model.semantic_model = model.semantic_model.to(device)
369
+ model.semantic_model.eval()
370
+
371
+ # Disable ALL gradient checkpointing
372
+ def disable_gradient_checkpointing(module):
373
+ if hasattr(module, 'gradient_checkpointing'):
374
+ module.gradient_checkpointing = False
375
+ if hasattr(module, 'gradient_checkpointing_disable'):
376
+ try:
377
+ module.gradient_checkpointing_disable()
378
+ except Exception:  # tolerate modules that do not expose this hook
379
+ pass
380
+ for child in module.children():
381
+ disable_gradient_checkpointing(child)
382
+
383
+ disable_gradient_checkpointing(model.semantic_model)
384
+
385
+ # For HuBERT specifically
386
+ if hasattr(model.semantic_model, 'encoder'):
387
+ model.semantic_model.encoder.gradient_checkpointing = False
388
+ if hasattr(model.semantic_model.encoder, 'layers'):
389
+ for layer in model.semantic_model.encoder.layers:
390
+ if hasattr(layer, 'gradient_checkpointing'):
391
+ layer.gradient_checkpointing = False
392
+
393
+ # 3. Set all dropout to eval mode
394
+ def set_dropout_eval(module):
395
+ if isinstance(module, nn.Dropout):
396
+ module.eval()
397
+ module.training = False
398
+ for child in module.children():
399
+ set_dropout_eval(child)
400
+
401
+ set_dropout_eval(model)
402
+
403
+ # 4. Clear any cached computations
404
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
405
+
406
+ return model
407
+
408
+ def inference_pipeline(checkpoint_path, config_path, device='cuda'):
409
+ """
410
+ Complete pipeline for inference with your trained model
411
+ """
412
+ # Load config
413
+ print("Loading config...")
414
+ with open(config_path, 'r') as f:
415
+ config = json.load(f)
416
+
417
+ # Create model
418
+ print("Creating model...")
419
+ model = HiggsAudioTokenizer(
420
+ n_filters=config['n_filters'],
421
+ D=config['D'],
422
+ target_bandwidths=config['target_bandwidths'],
423
+ ratios=config['ratios'],
424
+ sample_rate=config['sample_rate'],
425
+ bins=config['bins'],
426
+ n_q=config['n_q'],
427
+ codebook_dim=config.get('codebook_dim', None),
428
+ semantic_techer=config['semantic_techer'],
429
+ device=device
430
+ ).to(device)
431
+
432
+ # Load checkpoint
433
+ print("Loading checkpoint...")
434
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
435
+
436
+ if 'model_state_dict' in checkpoint:
437
+ state_dict = checkpoint['model_state_dict']
438
+ else:
439
+ state_dict = checkpoint
440
+
441
+ # Remove 'module.' prefix if present (from DDP)
442
+ new_state_dict = {}
443
+ for k, v in state_dict.items():
444
+ if k.startswith('module.'):
445
+ new_state_dict[k[7:]] = v
446
+ else:
447
+ new_state_dict[k] = v
448
+
449
+ model.load_state_dict(new_state_dict, strict=False)
450
+
451
+ # Fix all inference issues
452
+ print("Fixing inference issues...")
453
+ model = fix_all_inference_issues(model)
454
+
455
+
456
+ return model
457
+
458
+
459
+
460
+ # # Add paths
461
+ # sys.path.insert(0, "/home/ubuntu/AP-BWE")
462
+
463
+ # Suppress warnings
464
+ warnings.filterwarnings("ignore")
465
+
466
+ # Configuration
467
+ OUTPUT_DIR = "/home/ubuntu/data_boson_44.1khz"
468
+ BATCH_SIZE = 32
469
+ SAMPLE_RATE = 44100
470
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
471
+ DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/Qanary_data"
472
+
473
+ # # Model paths
474
+ # CONFIG_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/config.json"
475
+ # MODEL_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/model.pth"
476
+
477
+ # --- Setup ---
478
+ print(f"Using device: {DEVICE}")
479
+
480
+ # Change to working directory
481
+ os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing")
482
+
483
+ # Load dataset
484
+ from datasets import load_from_disk
485
+
486
+
487
+ print(f"Loading dataset from: {DATASET_PATH}")
488
+ ds = load_from_disk(DATASET_PATH)
489
+ print(f"Dataset info: {ds}")
490
+
491
+ # Remove unnecessary columns
492
+ columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask']
493
+ existing_columns = [col for col in columns_to_remove if col in ds.column_names]
494
+ if existing_columns:
495
+ ds = ds.remove_columns(existing_columns)
496
+
497
+ df = ds.to_pandas()
498
+ print(f"Loaded {len(df)} files from dataset")
499
+
500
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
501
+ print(f"Output directory '{OUTPUT_DIR}' is ready.")
502
+
503
+ # --- Filter already processed ---
504
+ print("Checking for already processed files...")
505
+
506
+ def get_output_path(audio_path):
507
+ base_name = Path(audio_path).stem
508
+ return os.path.join(OUTPUT_DIR, f"{base_name}.pt")
509
+
510
+ # Filter
511
+ original_count = len(df)
512
+ df['output_exists'] = df['filename'].apply(lambda x: os.path.exists(get_output_path(x)))
513
+ df_filtered = df[~df['output_exists']].copy()
514
+ skipped_count = original_count - len(df_filtered)
515
+
516
+ print(f"Found {skipped_count} already processed files. Skipping them.")
517
+ print(f"Processing {len(df_filtered)} remaining files.")
518
+
519
+ if len(df_filtered) == 0:
520
+ print("All files have already been processed!")
521
+ exit()
522
+
523
+ # --- Load Model ---
524
+ print("Loading Higgs Audio Tokenizer model...")
525
+
526
+ from transformers import HubertModel
527
+ from higgs_audio_tokenizer import HiggsAudioTokenizer
528
+
529
+ # Load config
530
+ # with open(CONFIG_PATH, 'r') as f:
531
+ # config = json.load(f)
532
+
533
+ # # Initialize model
534
+ # model = HiggsAudioTokenizer(
535
+ # **config,
536
+ # device=DEVICE,
537
+ # )
538
+
539
+ # Load weights
540
+ # parameter_dict = torch.load(MODEL_PATH, map_location=DEVICE)
541
+ # _ = model.load_state_dict(parameter_dict, strict=False)
542
+ # model = model.to(DEVICE)
543
+ # _ = model.eval()
544
+
545
+
546
+ checkpoint_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/outputs_CQT/checkpoints/step_99000.pth'
547
+ config_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/config copy.json'
548
+ device = 'cuda'
549
+ model = inference_pipeline(checkpoint_path, config_path, device)
550
+ _ = model.eval()
551
+
552
+ model = remove_weight_norms_from_model(model)
553
+
554
+ print(f"Model loaded on {DEVICE}")
555
+
556
+ # Get hop length
557
+ hop_length = model.hop_length
558
+ print(f"Encoder hop length: {hop_length}")
559
+
560
+ # --- Batch Processing ---
561
+ print(f"\nStarting batch processing with batch size {BATCH_SIZE}...")
562
+
563
+ # Process in batches
564
+ filenames = df_filtered['filename'].tolist()
565
+ total_processed = 0
566
+ total_errors = 0
567
+
568
+ with torch.no_grad():
569
+ for batch_start in tqdm(range(0, len(filenames), BATCH_SIZE), desc="Processing batches"):
570
+ batch_end = min(batch_start + BATCH_SIZE, len(filenames))
571
+ batch_filenames = filenames[batch_start:batch_end]
572
+
573
+ batch_audio = []
574
+ batch_lengths = []
575
+ batch_outputs = []
576
+
577
+ # Load batch
578
+ for filename in batch_filenames:
579
+ output_path = get_output_path(filename)
580
+
581
+ # Skip if exists (race condition check)
582
+ if os.path.exists(output_path):
583
+ continue
584
+
585
+ try:
586
+ # Load audio
587
+ wav, _ = librosa.load(filename, sr=SAMPLE_RATE)
588
+ wav_tensor = torch.from_numpy(wav).float()
589
+
590
+ batch_audio.append(wav_tensor)
591
+ batch_lengths.append(len(wav))
592
+ batch_outputs.append(output_path)
593
+
594
+ except Exception as e:
595
+ print(f"\nError loading {filename}: {e}")
596
+ total_errors += 1
597
+ continue
598
+
599
+ if not batch_audio:
600
+ continue
601
+
602
+ # Pad batch to same length
603
+ max_len = max(len(x) for x in batch_audio)
604
+ padded_batch = []
605
+
606
+ for audio in batch_audio:
607
+ pad_len = max_len - len(audio)
608
+ if pad_len > 0:
609
+ audio = F.pad(audio, (0, pad_len), mode='constant', value=0)
610
+ # Don't add extra dimensions here, just collect the padded audio
611
+ padded_batch.append(audio)
612
+
613
+ # Convert list to tensor and add channel dimension
614
+ # Stack along batch dimension to get [B, T]
615
+ batch_tensor = torch.stack(padded_batch, dim=0) # [B, T]
616
+ # Add channel dimension
617
+ batch_tensor = batch_tensor.unsqueeze(1) # [B, 1, T]
618
+ batch_tensor = batch_tensor.to(DEVICE)
619
+
620
+ # Encode batch
621
+ try:
622
+ encoded = encode_batch(model, batch_tensor)
623
+ codes = encoded.audio_codes # [B, n_codebooks, T_compressed]
624
+
625
+ # Save each item
626
+ for idx, (output_path, orig_len) in enumerate(zip(batch_outputs, batch_lengths)):
627
+ # Calculate true code length
628
+ true_code_len = int(np.ceil(orig_len / hop_length))
629
+
630
+ # Extract non-padded codes
631
+ item_codes = codes[idx, :, :true_code_len].cpu()
632
+
633
+ # Save
634
+ torch.save(item_codes, output_path)
635
+ total_processed += 1
636
+
637
+ except Exception as e:
638
+ print(f"\nError encoding batch: {e}")
639
+ total_errors += len(batch_outputs)
640
+
641
+ print("\n" + "="*50)
642
+ print("PROCESSING COMPLETE!")
643
+ print("="*50)
644
+ print(f"Successfully processed: {total_processed} files")
645
+ print(f"Previously processed: {skipped_count} files")
646
+ print(f"Errors encountered: {total_errors} files")
647
+ print(f"Output directory: {OUTPUT_DIR}")
648
+
649
+ # Final count
650
+ final_count = len(list(Path(OUTPUT_DIR).glob("*.pt")))
651
+ print(f"Total .pt files in output: {final_count}")
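For quick reference, a minimal single-clip usage sketch of `encode_batch` (not part of the committed file): it assumes the `model` and 44.1 kHz rate loaded by the script above, plus a hypothetical `example.wav`.

import math
import librosa
import torch

wav, _ = librosa.load("example.wav", sr=44100)                  # hypothetical input file
x = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)     # [B=1, 1, T]

with torch.no_grad():
    codes = encode_batch(model, x).audio_codes                  # [1, n_codebooks, T_codes]

# One code frame per `hop_length` samples; this mirrors the per-item trimming in the batch loop above.
true_len = math.ceil(len(wav) / model.hop_length)
codes = codes[..., :true_len]
print(codes.shape)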
descriptaudiocodec/__init__.py ADDED
File without changes
descriptaudiocodec/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes).
 
descriptaudiocodec/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (181 Bytes).
 
descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc ADDED
Binary file (13.6 kB).
 
descriptaudiocodec/dac/model/__pycache__/base.cpython-312.pyc ADDED
Binary file (12.9 kB).
 
descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc ADDED
Binary file (17.6 kB).
 
descriptaudiocodec/dac/model/__pycache__/dac.cpython-312.pyc ADDED
Binary file (15.5 kB).
 
descriptaudiocodec/dac/model/base.py ADDED
@@ -0,0 +1,286 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
52
+ return cls(codes=codes, **artifacts["metadata"])
53
+
54
+
55
+ class CodecMixin:
56
+ @property
57
+ def padding(self):
58
+ if not hasattr(self, "_padding"):
59
+ self._padding = True
60
+ return self._padding
61
+
62
+ @padding.setter
63
+ def padding(self, value):
64
+ assert isinstance(value, bool)
65
+
66
+ layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
67
+
68
+ for layer in layers:
69
+ if value:
70
+ if hasattr(layer, "original_padding"):
71
+ layer.padding = layer.original_padding
72
+ else:
73
+ layer.original_padding = layer.padding
74
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
75
+
76
+ self._padding = value
77
+
78
+ def get_delay(self):
79
+ # Any number works here, delay is invariant to input length
80
+ l_out = self.get_output_length(0)
81
+ L = l_out
82
+
83
+ layers = []
84
+ for layer in self.modules():
85
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
86
+ layers.append(layer)
87
+
88
+ for layer in reversed(layers):
89
+ d = layer.dilation[0]
90
+ k = layer.kernel_size[0]
91
+ s = layer.stride[0]
92
+
93
+ if isinstance(layer, nn.ConvTranspose1d):
94
+ L = ((L - d * (k - 1) - 1) / s) + 1
95
+ elif isinstance(layer, nn.Conv1d):
96
+ L = (L - 1) * s + d * (k - 1) + 1
97
+
98
+ L = math.ceil(L)
99
+
100
+ l_in = L
101
+
102
+ return (l_in - l_out) // 2
103
+
104
+ def get_output_length(self, input_length):
105
+ L = input_length
106
+ # Calculate output length
107
+ for layer in self.modules():
108
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
109
+ d = layer.dilation[0]
110
+ k = layer.kernel_size[0]
111
+ s = layer.stride[0]
112
+
113
+ if isinstance(layer, nn.Conv1d):
114
+ L = ((L - d * (k - 1) - 1) / s) + 1
115
+ elif isinstance(layer, nn.ConvTranspose1d):
116
+ L = (L - 1) * s + d * (k - 1) + 1
117
+
118
+ L = math.floor(L)
119
+ return L
120
+
121
+ @torch.no_grad()
122
+ def compress(
123
+ self,
124
+ audio_path_or_signal: Union[str, Path, AudioSignal],
125
+ win_duration: float = 1.0,
126
+ verbose: bool = False,
127
+ normalize_db: float = -16,
128
+ n_quantizers: int = None,
129
+ ) -> DACFile:
130
+ """Processes an audio signal from a file or AudioSignal object into
131
+ discrete codes. This function processes the signal in short windows,
132
+ using constant GPU memory.
133
+
134
+ Parameters
135
+ ----------
136
+ audio_path_or_signal : Union[str, Path, AudioSignal]
137
+ audio signal to compress into discrete codes
138
+ win_duration : float, optional
139
+ window duration in seconds, by default 1.0
140
+ verbose : bool, optional
141
+ by default False
142
+ normalize_db : float, optional
143
+ normalize db, by default -16
144
+
145
+ Returns
146
+ -------
147
+ DACFile
148
+ Object containing compressed codes and metadata
149
+ required for decompression
150
+ """
151
+ audio_signal = audio_path_or_signal
152
+ if isinstance(audio_signal, (str, Path)):
153
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
154
+
155
+ self.eval()
156
+ original_padding = self.padding
157
+ original_device = audio_signal.device
158
+
159
+ audio_signal = audio_signal.clone()
160
+ original_sr = audio_signal.sample_rate
161
+
162
+ resample_fn = audio_signal.resample
163
+ loudness_fn = audio_signal.loudness
164
+
165
+ # If audio is >= 10 hours long, use the ffmpeg versions
166
+ if audio_signal.signal_duration >= 10 * 60 * 60:
167
+ resample_fn = audio_signal.ffmpeg_resample
168
+ loudness_fn = audio_signal.ffmpeg_loudness
169
+
170
+ original_length = audio_signal.signal_length
171
+ resample_fn(self.sample_rate)
172
+ input_db = loudness_fn()
173
+
174
+ if normalize_db is not None:
175
+ audio_signal.normalize(normalize_db)
176
+ audio_signal.ensure_max_of_audio()
177
+
178
+ nb, nac, nt = audio_signal.audio_data.shape
179
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
180
+ win_duration = audio_signal.signal_duration if win_duration is None else win_duration
181
+
182
+ if audio_signal.signal_duration <= win_duration:
183
+ # Unchunked compression (used if signal length < win duration)
184
+ self.padding = True
185
+ n_samples = nt
186
+ hop = nt
187
+ else:
188
+ # Chunked inference
189
+ self.padding = False
190
+ # Zero-pad signal on either side by the delay
191
+ audio_signal.zero_pad(self.delay, self.delay)
192
+ n_samples = int(win_duration * self.sample_rate)
193
+ # Round n_samples to nearest hop length multiple
194
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
195
+ hop = self.get_output_length(n_samples)
196
+
197
+ codes = []
198
+ range_fn = range if not verbose else tqdm.trange
199
+
200
+ for i in range_fn(0, nt, hop):
201
+ x = audio_signal[..., i : i + n_samples]
202
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
203
+
204
+ audio_data = x.audio_data.to(self.device)
205
+ audio_data = self.preprocess(audio_data, self.sample_rate)
206
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
207
+ codes.append(c.to(original_device))
208
+ chunk_length = c.shape[-1]
209
+
210
+ codes = torch.cat(codes, dim=-1)
211
+
212
+ dac_file = DACFile(
213
+ codes=codes,
214
+ chunk_length=chunk_length,
215
+ original_length=original_length,
216
+ input_db=input_db,
217
+ channels=nac,
218
+ sample_rate=original_sr,
219
+ padding=self.padding,
220
+ dac_version=SUPPORTED_VERSIONS[-1],
221
+ )
222
+
223
+ if n_quantizers is not None:
224
+ codes = codes[:, :n_quantizers, :]
225
+
226
+ self.padding = original_padding
227
+ return dac_file
228
+
229
+ @torch.no_grad()
230
+ def decompress(
231
+ self,
232
+ obj: Union[str, Path, DACFile],
233
+ verbose: bool = False,
234
+ ) -> AudioSignal:
235
+ """Reconstruct audio from a given .dac file
236
+
237
+ Parameters
238
+ ----------
239
+ obj : Union[str, Path, DACFile]
240
+ .dac file location or corresponding DACFile object.
241
+ verbose : bool, optional
242
+ Prints progress if True, by default False
243
+
244
+ Returns
245
+ -------
246
+ AudioSignal
247
+ Object with the reconstructed audio
248
+ """
249
+ self.eval()
250
+ if isinstance(obj, (str, Path)):
251
+ obj = DACFile.load(obj)
252
+
253
+ original_padding = self.padding
254
+ self.padding = obj.padding
255
+
256
+ range_fn = range if not verbose else tqdm.trange
257
+ codes = obj.codes
258
+ original_device = codes.device
259
+ chunk_length = obj.chunk_length
260
+ recons = []
261
+
262
+ for i in range_fn(0, codes.shape[-1], chunk_length):
263
+ c = codes[..., i : i + chunk_length].to(self.device)
264
+ z = self.quantizer.from_codes(c)[0]
265
+ r = self.decode(z)
266
+ recons.append(r.to(original_device))
267
+
268
+ recons = torch.cat(recons, dim=-1)
269
+ recons = AudioSignal(recons, self.sample_rate)
270
+
271
+ resample_fn = recons.resample
272
+ loudness_fn = recons.loudness
273
+
274
+ # If audio is >= 10 hours long, use the ffmpeg versions
275
+ if recons.signal_duration >= 10 * 60 * 60:
276
+ resample_fn = recons.ffmpeg_resample
277
+ loudness_fn = recons.ffmpeg_loudness
278
+
279
+ recons.normalize(obj.input_db)
280
+ resample_fn(obj.sample_rate)
281
+ recons = recons[..., : obj.original_length]
282
+ loudness_fn()
283
+ recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
284
+
285
+ self.padding = original_padding
286
+ return recons
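As an orientation for `CodecMixin`, a hedged round-trip sketch (not part of the commit): it assumes the `DAC` model defined in `dac/model/dac.py` below, an installed `audiotools`, a local `input.wav`, and an import path that may need adjusting to this repository's layout.

from audiotools import AudioSignal
from dac.model.dac import DAC        # import path assumed; adjust to the local package layout

model = DAC().eval()                 # DAC mixes in CodecMixin, so compress/decompress are available
signal = AudioSignal("input.wav")    # hypothetical input file
dac_file = model.compress(signal, win_duration=1.0)   # windowed encoding with constant GPU memory
dac_file.save("input.dac")           # codes stored as uint16 plus metadata (loudness, length, ...)
recon = model.decompress("input.dac")                 # AudioSignal trimmed back to the original length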
descriptaudiocodec/dac/model/dac.py ADDED
@@ -0,0 +1,365 @@
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from dac.nn.layers import Snake1d
13
+ from dac.nn.layers import WNConv1d
14
+ from dac.nn.layers import WNConvTranspose1d
15
+ from dac.nn.quantize import ResidualVectorQuantize
16
+
17
+
18
+ def init_weights(m):
19
+ if isinstance(m, nn.Conv1d):
20
+ nn.init.trunc_normal_(m.weight, std=0.02)
21
+ nn.init.constant_(m.bias, 0)
22
+
23
+
24
+ class ResidualUnit(nn.Module):
25
+ def __init__(self, dim: int = 16, dilation: int = 1):
26
+ super().__init__()
27
+ pad = ((7 - 1) * dilation) // 2
28
+ self.block = nn.Sequential(
29
+ Snake1d(dim),
30
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
31
+ Snake1d(dim),
32
+ WNConv1d(dim, dim, kernel_size=1),
33
+ )
34
+
35
+ def forward(self, x):
36
+ y = self.block(x)
37
+ pad = (x.shape[-1] - y.shape[-1]) // 2
38
+ if pad > 0:
39
+ x = x[..., pad:-pad]
40
+ return x + y
41
+
42
+
43
+ class EncoderBlock(nn.Module):
44
+ def __init__(self, dim: int = 16, stride: int = 1):
45
+ super().__init__()
46
+ self.block = nn.Sequential(
47
+ ResidualUnit(dim // 2, dilation=1),
48
+ ResidualUnit(dim // 2, dilation=3),
49
+ ResidualUnit(dim // 2, dilation=9),
50
+ Snake1d(dim // 2),
51
+ WNConv1d(
52
+ dim // 2,
53
+ dim,
54
+ kernel_size=2 * stride,
55
+ stride=stride,
56
+ padding=math.ceil(stride / 2),
57
+ ),
58
+ )
59
+
60
+ def forward(self, x):
61
+ return self.block(x)
62
+
63
+
64
+ class Encoder(nn.Module):
65
+ def __init__(
66
+ self,
67
+ d_model: int = 64,
68
+ strides: list = [2, 4, 8, 8],
69
+ d_latent: int = 256,
70
+ ):
71
+ super().__init__()
72
+ # Create first convolution
73
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
74
+
75
+ # Create EncoderBlocks that double channels as they downsample by `stride`
76
+ for stride in strides:
77
+ d_model *= 2
78
+ self.block += [EncoderBlock(d_model, stride=stride)]
79
+
80
+ # Create last convolution
81
+ self.block += [
82
+ Snake1d(d_model),
83
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
84
+ ]
85
+
86
+ # Wrap block into nn.Sequential
87
+ self.block = nn.Sequential(*self.block)
88
+ self.enc_dim = d_model
89
+
90
+ def forward(self, x):
91
+ return self.block(x)
92
+
93
+
94
+ class DecoderBlock(nn.Module):
95
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
96
+ super().__init__()
97
+ self.block = nn.Sequential(
98
+ Snake1d(input_dim),
99
+ WNConvTranspose1d(
100
+ input_dim,
101
+ output_dim,
102
+ kernel_size=2 * stride,
103
+ stride=stride,
104
+ padding=math.ceil(stride / 2),
105
+ output_padding=stride % 2, # out_pad,
106
+ ),
107
+ ResidualUnit(output_dim, dilation=1),
108
+ ResidualUnit(output_dim, dilation=3),
109
+ ResidualUnit(output_dim, dilation=9),
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.block(x)
114
+
115
+
116
+ class Decoder(nn.Module):
117
+ def __init__(
118
+ self,
119
+ input_channel,
120
+ channels,
121
+ rates,
122
+ d_out: int = 1,
123
+ ):
124
+ super().__init__()
125
+
126
+ # Add first conv layer
127
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
128
+
129
+ # Add upsampling + MRF blocks
130
+ for i, stride in enumerate(rates):
131
+ input_dim = channels // 2**i
132
+ output_dim = channels // 2 ** (i + 1)
133
+ if i == 1:
134
+ out_pad = 1
135
+ else:
136
+ out_pad = 0
137
+ layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]
138
+
139
+ # Add final conv layer
140
+ layers += [
141
+ Snake1d(output_dim),
142
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
143
+ # nn.Tanh(),
144
+ ]
145
+
146
+ self.model = nn.Sequential(*layers)
147
+
148
+ def forward(self, x):
149
+ return self.model(x)
150
+
151
+
152
+ class DAC(BaseModel, CodecMixin):
153
+ def __init__(
154
+ self,
155
+ encoder_dim: int = 64,
156
+ encoder_rates: List[int] = [2, 4, 8, 8],
157
+ latent_dim: int = None,
158
+ decoder_dim: int = 1536,
159
+ decoder_rates: List[int] = [8, 8, 4, 2],
160
+ n_codebooks: int = 9,
161
+ codebook_size: int = 1024,
162
+ codebook_dim: Union[int, list] = 8,
163
+ quantizer_dropout: bool = False,
164
+ sample_rate: int = 44100,
165
+ ):
166
+ super().__init__()
167
+
168
+ self.encoder_dim = encoder_dim
169
+ self.encoder_rates = encoder_rates
170
+ self.decoder_dim = decoder_dim
171
+ self.decoder_rates = decoder_rates
172
+ self.sample_rate = sample_rate
173
+
174
+ if latent_dim is None:
175
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
176
+
177
+ self.latent_dim = latent_dim
178
+
179
+ self.hop_length = np.prod(encoder_rates)
180
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
181
+
182
+ self.n_codebooks = n_codebooks
183
+ self.codebook_size = codebook_size
184
+ self.codebook_dim = codebook_dim
185
+ self.quantizer = ResidualVectorQuantize(
186
+ input_dim=latent_dim,
187
+ n_codebooks=n_codebooks,
188
+ codebook_size=codebook_size,
189
+ codebook_dim=codebook_dim,
190
+ quantizer_dropout=quantizer_dropout,
191
+ )
192
+
193
+ self.decoder = Decoder(
194
+ latent_dim,
195
+ decoder_dim,
196
+ decoder_rates,
197
+ )
198
+ self.sample_rate = sample_rate
199
+ self.apply(init_weights)
200
+
201
+ self.delay = self.get_delay()
202
+
203
+ def preprocess(self, audio_data, sample_rate):
204
+ if sample_rate is None:
205
+ sample_rate = self.sample_rate
206
+ assert sample_rate == self.sample_rate
207
+
208
+ length = audio_data.shape[-1]
209
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
210
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
211
+
212
+ return audio_data
213
+
214
+ def encode(
215
+ self,
216
+ audio_data: torch.Tensor,
217
+ n_quantizers: int = None,
218
+ ):
219
+ """Encode given audio data and return quantized latent codes
220
+
221
+ Parameters
222
+ ----------
223
+ audio_data : Tensor[B x 1 x T]
224
+ Audio data to encode
225
+ n_quantizers : int, optional
226
+ Number of quantizers to use, by default None
227
+ If None, all quantizers are used.
228
+
229
+ Returns
230
+ -------
231
+ dict
232
+ A dictionary with the following keys:
233
+ "z" : Tensor[B x D x T]
234
+ Quantized continuous representation of input
235
+ "codes" : Tensor[B x N x T]
236
+ Codebook indices for each codebook
237
+ (quantized discrete representation of input)
238
+ "latents" : Tensor[B x N*D x T]
239
+ Projected latents (continuous representation of input before quantization)
240
+ "vq/commitment_loss" : Tensor[1]
241
+ Commitment loss to train encoder to predict vectors closer to codebook
242
+ entries
243
+ "vq/codebook_loss" : Tensor[1]
244
+ Codebook loss to update the codebook
245
+ "length" : int
246
+ Number of samples in input audio
247
+ """
248
+ z = self.encoder(audio_data)
249
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
250
+ return z, codes, latents, commitment_loss, codebook_loss
251
+
252
+ def decode(self, z: torch.Tensor):
253
+ """Decode given latent codes and return audio data
254
+
255
+ Parameters
256
+ ----------
257
+ z : Tensor[B x D x T]
258
+ Quantized continuous representation of input
259
+ length : int, optional
260
+ Number of samples in output audio, by default None
261
+
262
+ Returns
263
+ -------
264
+ dict
265
+ A dictionary with the following keys:
266
+ "audio" : Tensor[B x 1 x length]
267
+ Decoded audio data.
268
+ """
269
+ return self.decoder(z)
270
+
271
+ def forward(
272
+ self,
273
+ audio_data: torch.Tensor,
274
+ sample_rate: int = None,
275
+ n_quantizers: int = None,
276
+ ):
277
+ """Model forward pass
278
+
279
+ Parameters
280
+ ----------
281
+ audio_data : Tensor[B x 1 x T]
282
+ Audio data to encode
283
+ sample_rate : int, optional
284
+ Sample rate of audio data in Hz, by default None
285
+ If None, defaults to `self.sample_rate`
286
+ n_quantizers : int, optional
287
+ Number of quantizers to use, by default None.
288
+ If None, all quantizers are used.
289
+
290
+ Returns
291
+ -------
292
+ dict
293
+ A dictionary with the following keys:
294
+ "z" : Tensor[B x D x T]
295
+ Quantized continuous representation of input
296
+ "codes" : Tensor[B x N x T]
297
+ Codebook indices for each codebook
298
+ (quantized discrete representation of input)
299
+ "latents" : Tensor[B x N*D x T]
300
+ Projected latents (continuous representation of input before quantization)
301
+ "vq/commitment_loss" : Tensor[1]
302
+ Commitment loss to train encoder to predict vectors closer to codebook
303
+ entries
304
+ "vq/codebook_loss" : Tensor[1]
305
+ Codebook loss to update the codebook
306
+ "length" : int
307
+ Number of samples in input audio
308
+ "audio" : Tensor[B x 1 x length]
309
+ Decoded audio data.
310
+ """
311
+ length = audio_data.shape[-1]
312
+ audio_data = self.preprocess(audio_data, sample_rate)
313
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
314
+
315
+ x = self.decode(z)
316
+ return {
317
+ "audio": x[..., :length],
318
+ "z": z,
319
+ "codes": codes,
320
+ "latents": latents,
321
+ "vq/commitment_loss": commitment_loss,
322
+ "vq/codebook_loss": codebook_loss,
323
+ }
324
+
325
+
326
+ if __name__ == "__main__":
327
+ import numpy as np
328
+ from functools import partial
329
+
330
+ model = DAC().to("cpu")
331
+
332
+ for n, m in model.named_modules():
333
+ o = m.extra_repr()
334
+ p = sum([np.prod(p.size()) for p in m.parameters()])
335
+ fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
336
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
337
+ print(model)
338
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
339
+
340
+ length = 88200 * 2
341
+ x = torch.randn(1, 1, length).to(model.device)
342
+ x.requires_grad_(True)
343
+ x.retain_grad()
344
+
345
+ # Make a forward pass
346
+ out = model(x)["audio"]
347
+ print("Input shape:", x.shape)
348
+ print("Output shape:", out.shape)
349
+
350
+ # Create gradient variable
351
+ grad = torch.zeros_like(out)
352
+ grad[:, :, grad.shape[-1] // 2] = 1
353
+
354
+ # Make a backward pass
355
+ out.backward(grad)
356
+
357
+ # Check non-zero values
358
+ gradmap = x.grad.squeeze(0)
359
+ gradmap = (gradmap != 0).sum(0) # sum across features
360
+ rf = (gradmap != 0).sum()
361
+
362
+ print(f"Receptive field: {rf.item()}")
363
+
364
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
365
+ model.decompress(model.compress(x, verbose=True), verbose=True)
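A short sketch of the encode/decode split for cases where only the integer codes are stored (hedged example, not part of the file; it assumes the same `DAC` construction as the `__main__` block above and an assumed import path):

import torch
from dac.model.dac import DAC        # import path assumed

model = DAC().eval()
x = torch.randn(1, 1, 44100)                          # one second of audio at 44.1 kHz
x = model.preprocess(x, model.sample_rate)            # right-pad to a multiple of hop_length

with torch.no_grad():
    z, codes, latents, commitment_loss, codebook_loss = model.encode(x)
    z_q = model.quantizer.from_codes(codes)[0]        # rebuild the continuous latent from codes
    audio = model.decode(z_q)

print(codes.shape)                                    # [B, n_codebooks, T_codes]
print(audio.shape)                                    # [B, 1, T_padded]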
descriptaudiocodec/dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch.nn.utils import weight_norm
+
+
+ def WNConv1d(*args, **kwargs):
+     return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+ def WNConvTranspose1d(*args, **kwargs):
+     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+ # Scripting this brings model speed up 1.4x
+ @torch.jit.script
+ def snake(x, alpha):
+     shape = x.shape
+     x = x.reshape(shape[0], shape[1], -1)
+     x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+     x = x.reshape(shape)
+     return x
+
+
+ class Snake1d(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+     def forward(self, x):
+         return snake(x, self.alpha)
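A tiny sanity check of the snake activation above (hedged example, not part of the file): at initialisation `alpha == 1`, so `Snake1d` should reduce to `x + sin(x)**2` up to the 1e-9 epsilon.

import torch

act = Snake1d(channels=4)
x = torch.randn(2, 4, 16)                       # [B, C, T]
y = act(x)
expected = x + torch.sin(x).pow(2)              # alpha is initialised to ones
print(y.shape, torch.allclose(y, expected, atol=1e-6))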
descriptaudiocodec/dac/nn/quantize.py ADDED
@@ -0,0 +1,251 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from dac.nn.layers import WNConv1d
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantizes the input tensor using a fixed codebook and returns
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
65
+
66
+ z_q = self.out_proj(z_q)
67
+
68
+ return z_q, commitment_loss, codebook_loss, indices, z_e
69
+
70
+ def embed_code(self, embed_id):
71
+ return F.embedding(embed_id, self.codebook.weight)
72
+
73
+ def decode_code(self, embed_id):
74
+ return self.embed_code(embed_id).transpose(1, 2)
75
+
76
+ def decode_latents(self, latents):
77
+ encodings = rearrange(latents, "b d t -> (b t) d")
78
+ codebook = self.codebook.weight # codebook: (N x D)
79
+
80
+ # L2 normalize encodings and codebook (ViT-VQGAN)
81
+ encodings = F.normalize(encodings)
82
+ codebook = F.normalize(codebook)
83
+
84
+ # Compute euclidean distance with codebook
85
+ dist = (
86
+ encodings.pow(2).sum(1, keepdim=True)
87
+ - 2 * encodings @ codebook.t()
88
+ + codebook.pow(2).sum(1, keepdim=True).t()
89
+ )
90
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
91
+ z_q = self.decode_code(indices)
92
+ return z_q, indices
93
+
94
+
95
+ class ResidualVectorQuantize(nn.Module):
96
+ """
97
+ Introduced in SoundStream: An end2end neural audio codec
98
+ https://arxiv.org/abs/2107.03312
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ input_dim: int = 512,
104
+ n_codebooks: int = 9,
105
+ codebook_size: int = 1024,
106
+ codebook_dim: Union[int, list] = 8,
107
+ quantizer_dropout: float = 0.0,
108
+ ):
109
+ super().__init__()
110
+ if isinstance(codebook_dim, int):
111
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
112
+
113
+ self.n_codebooks = n_codebooks
114
+ self.codebook_dim = codebook_dim
115
+ self.codebook_size = codebook_size
116
+
117
+ self.quantizers = nn.ModuleList(
118
+ [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
119
+ )
120
+ self.quantizer_dropout = quantizer_dropout
121
+
122
+ def forward(self, z, n_quantizers: int = None):
123
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
124
+ the corresponding codebook vectors
125
+ Parameters
126
+ ----------
127
+ z : Tensor[B x D x T]
128
+ n_quantizers : int, optional
129
+ No. of quantizers to use
130
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
131
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
132
+ when in training mode, and a random number of quantizers is used.
133
+ Returns
134
+ -------
135
+ dict
136
+ A dictionary with the following keys:
137
+
138
+ "z" : Tensor[B x D x T]
139
+ Quantized continuous representation of input
140
+ "codes" : Tensor[B x N x T]
141
+ Codebook indices for each codebook
142
+ (quantized discrete representation of input)
143
+ "latents" : Tensor[B x N*D x T]
144
+ Projected latents (continuous representation of input before quantization)
145
+ "vq/commitment_loss" : Tensor[1]
146
+ Commitment loss to train encoder to predict vectors closer to codebook
147
+ entries
148
+ "vq/codebook_loss" : Tensor[1]
149
+ Codebook loss to update the codebook
150
+ """
151
+ z_q = 0
152
+ residual = z
153
+ commitment_loss = 0
154
+ codebook_loss = 0
155
+
156
+ codebook_indices = []
157
+ latents = []
158
+
159
+ if n_quantizers is None:
160
+ n_quantizers = self.n_codebooks
161
+ if self.training:
162
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
163
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
164
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
165
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
166
+ n_quantizers = n_quantizers.to(z.device)
167
+
168
+ for i, quantizer in enumerate(self.quantizers):
169
+ if self.training is False and i >= n_quantizers:
170
+ break
171
+
172
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
173
+
174
+ # Create mask to apply quantizer dropout
175
+ mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
176
+ z_q = z_q + z_q_i * mask[:, None, None]
177
+ residual = residual - z_q_i
178
+
179
+ # Sum losses
180
+ commitment_loss += (commitment_loss_i * mask).mean()
181
+ codebook_loss += (codebook_loss_i * mask).mean()
182
+
183
+ codebook_indices.append(indices_i)
184
+ latents.append(z_e_i)
185
+
186
+ codes = torch.stack(codebook_indices, dim=1)
187
+ latents = torch.cat(latents, dim=1)
188
+
189
+ return z_q, codes, latents, commitment_loss, codebook_loss
190
+
191
+ def from_codes(self, codes: torch.Tensor):
192
+ """Given the quantized codes, reconstruct the continuous representation
193
+ Parameters
194
+ ----------
195
+ codes : Tensor[B x N x T]
196
+ Quantized discrete representation of input
197
+ Returns
198
+ -------
199
+ Tensor[B x D x T]
200
+ Quantized continuous representation of input
201
+ """
202
+ z_q = 0.0
203
+ z_p = []
204
+ n_codebooks = codes.shape[1]
205
+ for i in range(n_codebooks):
206
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
207
+ z_p.append(z_p_i)
208
+
209
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
210
+ z_q = z_q + z_q_i
211
+ return z_q, torch.cat(z_p, dim=1), codes
212
+
213
+ def from_latents(self, latents: torch.Tensor):
214
+ """Given the unquantized latents, reconstruct the
215
+ continuous representation after quantization.
216
+
217
+ Parameters
218
+ ----------
219
+ latents : Tensor[B x N x T]
220
+ Continuous representation of input after projection
221
+
222
+ Returns
223
+ -------
224
+ Tensor[B x D x T]
225
+ Quantized representation of full-projected space
226
+ Tensor[B x D x T]
227
+ Quantized representation of latent space
228
+ """
229
+ z_q = 0
230
+ z_p = []
231
+ codes = []
232
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
233
+
234
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
235
+ for i in range(n_codebooks):
236
+ j, k = dims[i], dims[i + 1]
237
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
238
+ z_p.append(z_p_i)
239
+ codes.append(codes_i)
240
+
241
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
242
+ z_q = z_q + z_q_i
243
+
244
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
245
+
246
+
247
+ if __name__ == "__main__":
248
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
249
+ x = torch.randn(16, 512, 80)
250
+ y = rvq(x)
251
+ print(y["latents"].shape)
discriminator.py ADDED
@@ -0,0 +1,596 @@
1
+ # import torch
2
+ # import torch.nn as nn
3
+ # import torch.nn.functional as F
4
+ # from audiotools import AudioSignal
5
+ # from audiotools import ml
6
+ # from audiotools import STFTParams
7
+ # from einops import rearrange
8
+ # from torch.nn.utils import weight_norm
9
+
10
+
11
+ # def WNConv1d(*args, **kwargs):
12
+ # act = kwargs.pop("act", True)
13
+ # conv = weight_norm(nn.Conv1d(*args, **kwargs))
14
+ # if not act:
15
+ # return conv
16
+ # return nn.Sequential(conv, nn.LeakyReLU(0.1))
17
+
18
+
19
+ # def WNConv2d(*args, **kwargs):
20
+ # act = kwargs.pop("act", True)
21
+ # conv = weight_norm(nn.Conv2d(*args, **kwargs))
22
+ # if not act:
23
+ # return conv
24
+ # return nn.Sequential(conv, nn.LeakyReLU(0.1))
25
+
26
+
27
+ # class MPD(nn.Module):
28
+ # def __init__(self, period):
29
+ # super().__init__()
30
+ # self.period = period
31
+ # self.convs = nn.ModuleList(
32
+ # [
33
+ # WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34
+ # WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35
+ # WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36
+ # WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37
+ # WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38
+ # ]
39
+ # )
40
+ # self.conv_post = WNConv2d(
41
+ # 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42
+ # )
43
+
44
+ # def pad_to_period(self, x):
45
+ # t = x.shape[-1]
46
+ # x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47
+ # return x
48
+
49
+ # def forward(self, x):
50
+ # fmap = []
51
+
52
+ # x = self.pad_to_period(x)
53
+ # x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54
+
55
+ # for layer in self.convs:
56
+ # x = layer(x)
57
+ # fmap.append(x)
58
+
59
+ # x = self.conv_post(x)
60
+ # fmap.append(x)
61
+
62
+ # return fmap
63
+
64
+
65
+ # class MSD(nn.Module):
66
+ # def __init__(self, rate: int = 1, sample_rate: int = 44100):
67
+ # super().__init__()
68
+ # self.convs = nn.ModuleList(
69
+ # [
70
+ # WNConv1d(1, 16, 15, 1, padding=7),
71
+ # WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72
+ # WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73
+ # WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74
+ # WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75
+ # WNConv1d(1024, 1024, 5, 1, padding=2),
76
+ # ]
77
+ # )
78
+ # self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79
+ # self.sample_rate = sample_rate
80
+ # self.rate = rate
81
+
82
+ # def forward(self, x):
83
+ # x = AudioSignal(x, self.sample_rate)
84
+ # x.resample(self.sample_rate // self.rate)
85
+ # x = x.audio_data
86
+
87
+ # fmap = []
88
+
89
+ # for l in self.convs:
90
+ # x = l(x)
91
+ # fmap.append(x)
92
+ # x = self.conv_post(x)
93
+ # fmap.append(x)
94
+
95
+ # return fmap
96
+
97
+
98
+ # BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99
+
100
+
101
+ # class MRD(nn.Module):
102
+ # def __init__(
103
+ # self,
104
+ # window_length: int,
105
+ # hop_factor: float = 0.25,
106
+ # sample_rate: int = 44100,
107
+ # bands: list = BANDS,
108
+ # ):
109
+ # """Complex multi-band spectrogram discriminator.
110
+ # Parameters
111
+ # ----------
112
+ # window_length : int
113
+ # Window length of STFT.
114
+ # hop_factor : float, optional
115
+ # Hop factor of the STFT, defaults to ``0.25 * window_length``.
116
+ # sample_rate : int, optional
117
+ # Sampling rate of audio in Hz, by default 44100
118
+ # bands : list, optional
119
+ # Bands to run discriminator over.
120
+ # """
121
+ # super().__init__()
122
+
123
+ # self.window_length = window_length
124
+ # self.hop_factor = hop_factor
125
+ # self.sample_rate = sample_rate
126
+ # self.stft_params = STFTParams(
127
+ # window_length=window_length,
128
+ # hop_length=int(window_length * hop_factor),
129
+ # match_stride=True,
130
+ # )
131
+
132
+ # n_fft = window_length // 2 + 1
133
+ # bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134
+ # self.bands = bands
135
+
136
+ # ch = 32
137
+ # convs = lambda: nn.ModuleList(
138
+ # [
139
+ # WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140
+ # WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141
+ # WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142
+ # WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143
+ # WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144
+ # ]
145
+ # )
146
+ # self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147
+ # self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
+
149
+ # def spectrogram(self, x):
150
+ # x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151
+ # x = torch.view_as_real(x.stft())
152
+ # x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153
+ # # Split into bands
154
+ # x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155
+ # return x_bands
156
+
157
+ # def forward(self, x):
158
+ # x_bands = self.spectrogram(x)
159
+ # fmap = []
160
+
161
+ # x = []
162
+ # for band, stack in zip(x_bands, self.band_convs):
163
+ # for layer in stack:
164
+ # band = layer(band)
165
+ # fmap.append(band)
166
+ # x.append(band)
167
+
168
+ # x = torch.cat(x, dim=-1)
169
+ # x = self.conv_post(x)
170
+ # fmap.append(x)
171
+
172
+ # return fmap
173
+
174
+
175
+ # class Discriminator(ml.BaseModel):
176
+ # def __init__(
177
+ # self,
178
+ # rates: list = [],
179
+ # periods: list = [2, 3, 5, 7, 11],
180
+ # fft_sizes: list = [2048, 1024, 512],
181
+ # sample_rate: int = 44100,
182
+ # bands: list = BANDS,
183
+ # ):
184
+ # """Discriminator that combines multiple discriminators.
185
+
186
+ # Parameters
187
+ # ----------
188
+ # rates : list, optional
189
+ # sampling rates (in Hz) to run MSD at, by default []
190
+ # If empty, MSD is not used.
191
+ # periods : list, optional
192
+ # periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193
+ # fft_sizes : list, optional
194
+ # Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195
+ # sample_rate : int, optional
196
+ # Sampling rate of audio in Hz, by default 44100
197
+ # bands : list, optional
198
+ # Bands to run MRD at, by default `BANDS`
199
+ # """
200
+ # super().__init__()
201
+ # discs = []
202
+ # discs += [MPD(p) for p in periods]
203
+ # discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204
+ # discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205
+ # self.discriminators = nn.ModuleList(discs)
206
+
207
+ # def preprocess(self, y):
208
+ # # Remove DC offset
209
+ # y = y - y.mean(dim=-1, keepdims=True)
210
+ # # Peak normalize the volume of input audio
211
+ # y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212
+ # return y
213
+
214
+ # def forward(self, x):
215
+ # x = self.preprocess(x)
216
+ # fmaps = [d(x) for d in self.discriminators]
217
+ # return fmaps
218
+
219
+
220
+ # if __name__ == "__main__":
221
+ # disc = Discriminator()
222
+ # x = torch.zeros(1, 1, 44100)
223
+ # results = disc(x)
224
+ # for i, result in enumerate(results):
225
+ # print(f"disc{i}")
226
+ # for i, r in enumerate(result):
227
+ # print(r.shape, r.mean(), r.min(), r.max())
228
+ # print()
229
+ import torch
230
+ import torch.nn as nn
231
+ import torch.nn.functional as F
232
+ from audiotools import AudioSignal, STFTParams
233
+ from audiotools import ml
234
+ from einops import rearrange
235
+ from torch.nn.utils import weight_norm
236
+ import torchaudio
237
+ import nnAudio.features as features
238
+ from munch import Munch
239
+
240
+
241
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
242
+
243
+
244
+ def WNConv1d(*args, **kwargs):
245
+ act = kwargs.pop("act", True)
246
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
247
+ if not act:
248
+ return conv
249
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
250
+
251
+
252
+ def WNConv2d(*args, **kwargs):
253
+ act = kwargs.pop("act", True)
254
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
255
+ if not act:
256
+ return conv
257
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
258
+
259
+
260
+ def get_padding(kernel_size, dilation=1):
261
+ return int((kernel_size * dilation - dilation) / 2)
262
+
263
+
264
+ def get_2d_padding(kernel_size, dilation=(1, 1)):
265
+ return (int((kernel_size[0] * dilation[0] - dilation[0]) / 2),
266
+ int((kernel_size[1] * dilation[1] - dilation[1]) / 2))
267
+
268
+
269
+ class NormConv2d(nn.Module):
270
+ """Conv2d with normalization"""
271
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
272
+ padding=0, dilation=1, groups=1, bias=True, norm="weight_norm"):
273
+ super().__init__()
274
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
275
+ stride, padding, dilation, groups, bias)
276
+ if norm == "weight_norm":
277
+ self.conv = weight_norm(self.conv)
278
+
279
+ def forward(self, x):
280
+ return self.conv(x)
281
+
282
+
283
+ class MPD(nn.Module):
284
+ def __init__(self, period):
285
+ super().__init__()
286
+ self.period = period
287
+ self.convs = nn.ModuleList([
288
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
289
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
290
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
291
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
292
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
293
+ ])
294
+ self.conv_post = WNConv2d(1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False)
295
+
296
+ def pad_to_period(self, x):
297
+ t = x.shape[-1]
298
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
299
+ return x
300
+
301
+ def forward(self, x):
302
+ fmap = []
303
+ x = self.pad_to_period(x)
304
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
305
+
306
+ for layer in self.convs:
307
+ x = layer(x)
308
+ fmap.append(x)
309
+
310
+ x = self.conv_post(x)
311
+ fmap.append(x)
312
+ return fmap
313
+
314
+
315
+ class MSD(nn.Module):
316
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
317
+ super().__init__()
318
+ self.convs = nn.ModuleList([
319
+ WNConv1d(1, 16, 15, 1, padding=7),
320
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
321
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
322
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
323
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
324
+ WNConv1d(1024, 1024, 5, 1, padding=2),
325
+ ])
326
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
327
+ self.sample_rate = sample_rate
328
+ self.rate = rate
329
+
330
+ def forward(self, x):
331
+ x = AudioSignal(x, self.sample_rate)
332
+ x.resample(self.sample_rate // self.rate)
333
+ x = x.audio_data
334
+
335
+ fmap = []
336
+ for l in self.convs:
337
+ x = l(x)
338
+ fmap.append(x)
339
+ x = self.conv_post(x)
340
+ fmap.append(x)
341
+ return fmap
342
+
343
+
344
+ class DiscriminatorCQT(nn.Module):
345
+ def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
346
+ super().__init__()
347
+ self.cfg = cfg
348
+ self.filters = cfg.filters
349
+ self.max_filters = cfg.max_filters
350
+ self.filters_scale = cfg.filters_scale
351
+ self.kernel_size = (3, 9)
352
+ self.dilations = cfg.dilations
353
+ self.stride = (1, 2)
354
+ self.in_channels = cfg.in_channels
355
+ self.out_channels = cfg.out_channels
356
+ self.fs = cfg.sampling_rate
357
+ self.hop_length = hop_length
358
+ self.n_octaves = n_octaves
359
+ self.bins_per_octave = bins_per_octave
360
+
361
+ self.cqt_transform = features.cqt.CQT2010v2(
362
+ sr=self.fs * 2,
363
+ hop_length=self.hop_length,
364
+ n_bins=self.bins_per_octave * self.n_octaves,
365
+ bins_per_octave=self.bins_per_octave,
366
+ output_format="Complex",
367
+ pad_mode="constant",
368
+ )
369
+
370
+ self.conv_pres = nn.ModuleList()
371
+ for i in range(self.n_octaves):
372
+ self.conv_pres.append(
373
+ NormConv2d(
374
+ self.in_channels * 2, # Real + Imaginary
375
+ self.in_channels * 2,
376
+ kernel_size=self.kernel_size,
377
+ padding=get_2d_padding(self.kernel_size),
378
+ norm="weight_norm",
379
+ )
380
+ )
381
+
382
+ self.convs = nn.ModuleList()
383
+ self.convs.append(
384
+ NormConv2d(
385
+ self.in_channels * 2,
386
+ self.filters,
387
+ kernel_size=self.kernel_size,
388
+ padding=get_2d_padding(self.kernel_size),
389
+ )
390
+ )
391
+
392
+ in_chs = min(self.filters_scale * self.filters, self.max_filters)
393
+ for i, dilation in enumerate(self.dilations):
394
+ out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
395
+ self.convs.append(
396
+ NormConv2d(
397
+ in_chs,
398
+ out_chs,
399
+ kernel_size=self.kernel_size,
400
+ stride=self.stride,
401
+ dilation=(dilation, 1),
402
+ padding=get_2d_padding(self.kernel_size, (dilation, 1)),
403
+ norm="weight_norm",
404
+ )
405
+ )
406
+ in_chs = out_chs
407
+
408
+ out_chs = min(
409
+ (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
410
+ self.max_filters,
411
+ )
412
+ self.convs.append(
413
+ NormConv2d(
414
+ in_chs,
415
+ out_chs,
416
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
417
+ padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
418
+ norm="weight_norm",
419
+ )
420
+ )
421
+
422
+ self.conv_post = NormConv2d(
423
+ out_chs,
424
+ self.out_channels,
425
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
426
+ padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
427
+ norm="weight_norm",
428
+ )
429
+
430
+ self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
431
+ self.resample = torchaudio.transforms.Resample(
432
+ orig_freq=self.fs, new_freq=self.fs * 2
433
+ )
434
+
435
+ def forward(self, x):
436
+ fmap = []
437
+ x = self.resample(x)
438
+ z = self.cqt_transform(x)
439
+
440
+
441
+ z_amplitude = z[:, :, :, 0].unsqueeze(1)  # real part of the complex CQT output
442
+ z_phase = z[:, :, :, 1].unsqueeze(1)  # imaginary part of the complex CQT output
443
+ z = torch.cat([z_amplitude, z_phase], dim=1)
444
+ z = rearrange(z, "b c w t -> b c t w")
445
+
446
+ latent_z = []
447
+ for i in range(self.n_octaves):
448
+ octave_band = z[:, :, :, i * self.bins_per_octave : (i + 1) * self.bins_per_octave]
449
+ processed_band = self.conv_pres[i](octave_band)
450
+ latent_z.append(processed_band)
451
+ latent_z = torch.cat(latent_z, dim=-1)
452
+
453
+ for i, l in enumerate(self.convs):
454
+ latent_z = l(latent_z)
455
+ latent_z = self.activation(latent_z)
456
+ fmap.append(latent_z)
457
+
458
+ latent_z = self.conv_post(latent_z)
459
+ fmap.append(latent_z)
460
+
461
+ return fmap
462
+
463
+
464
+ class MultiScaleSubbandCQT(nn.Module):
465
+ """CQT discriminator at multiple scales"""
466
+ def __init__(self, sample_rate=44100):
467
+ super().__init__()
468
+ cfg = Munch({
469
+ "hop_lengths": [1024, 512, 512],
470
+ "sampling_rate": sample_rate,
471
+ "filters": 32,
472
+ "max_filters": 1024,
473
+ "filters_scale": 1,
474
+ "dilations": [1, 2, 4],
475
+ "in_channels": 1,
476
+ "out_channels": 1,
477
+ "n_octaves": [10, 10, 10],
478
+ "bins_per_octaves": [24, 36, 48],
479
+ })
480
+ self.cfg = cfg
481
+ self.discriminators = nn.ModuleList([
482
+ DiscriminatorCQT(
483
+ cfg,
484
+ hop_length=cfg.hop_lengths[i],
485
+ n_octaves=cfg.n_octaves[i],
486
+ bins_per_octave=cfg.bins_per_octaves[i],
487
+ )
488
+ for i in range(len(cfg.hop_lengths))
489
+ ])
490
+
491
+ def forward(self, x):
492
+ fmap = []
493
+ for disc in self.discriminators:
494
+ fmap.extend(disc(x))
495
+ return fmap
496
+
497
+
498
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
499
+
500
+ class MRD(nn.Module):
501
+ def __init__(self, window_length: int, hop_factor: float = 0.25,
502
+ sample_rate: int = 44100, bands: list = BANDS):
503
+ """Multi-resolution spectrogram discriminator."""
504
+ super().__init__()
505
+ self.window_length = window_length
506
+ self.hop_factor = hop_factor
507
+ self.sample_rate = sample_rate
508
+ self.stft_params = STFTParams(
509
+ window_length=window_length,
510
+ hop_length=int(window_length * hop_factor),
511
+ match_stride=True,
512
+ )
513
+
514
+ n_fft = window_length // 2 + 1
515
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
516
+ self.bands = bands
517
+
518
+ ch = 32
519
+ convs = lambda: nn.ModuleList([
520
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
521
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
522
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
523
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
524
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
525
+ ])
526
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
527
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
528
+
529
+ def spectrogram(self, x):
530
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
531
+ x = torch.view_as_real(x.stft())
532
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
533
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
534
+ return x_bands
535
+
536
+ def forward(self, x):
537
+ x_bands = self.spectrogram(x)
538
+ fmap = []
539
+
540
+ x = []
541
+ for band, stack in zip(x_bands, self.band_convs):
542
+ for layer in stack:
543
+ band = layer(band)
544
+ fmap.append(band)
545
+ x.append(band)
546
+
547
+ x = torch.cat(x, dim=-1)
548
+ x = self.conv_post(x)
549
+ fmap.append(x)
550
+ return fmap
551
+
552
+
553
+ class Discriminator(ml.BaseModel):
554
+ def __init__(
555
+ self,
556
+ rates: list = [],
557
+ periods: list = [2, 3, 5, 7, 11],
558
+ fft_sizes: list = [2048, 1024, 512],
559
+ sample_rate: int = 44100,
560
+ ):
561
+ """Discriminator combining MPD, MSD, MRD and CQT.
562
+
563
+ Parameters
564
+ ----------
565
+ rates : list, optional
566
+ Sampling rates for MSD, by default []
567
+ periods : list, optional
568
+ Periods for MPD, by default [2, 3, 5, 7, 11]
569
+ fft_sizes : list, optional
570
+ FFT sizes for MRD, by default [2048, 1024, 512]
571
+ sample_rate : int, optional
572
+ Sampling rate of audio in Hz, by default 44100
573
+ """
574
+ super().__init__()
575
+ discs = []
576
+ # Time-domain discriminators
577
+ discs += [MPD(p) for p in periods]
578
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
579
+
580
+ # Frequency-domain discriminators (both STFT and CQT)
581
+ discs += [MRD(f, sample_rate=sample_rate) for f in fft_sizes]
582
+ discs += [MultiScaleSubbandCQT(sample_rate=sample_rate)]
583
+
584
+ self.discriminators = nn.ModuleList(discs)
585
+
586
+ def preprocess(self, y):
587
+ # Remove DC offset
588
+ y = y - y.mean(dim=-1, keepdims=True)
589
+ # Peak normalize
590
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
591
+ return y
592
+
593
+ def forward(self, x):
594
+ x = self.preprocess(x)
595
+ fmaps = [d(x) for d in self.discriminators]
596
+ return fmaps
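A minimal smoke test for the combined discriminator, in the spirit of the commented-out `__main__` block at the top of this file; it assumes audiotools, nnAudio and munch are installed and uses one second of random 44.1 kHz audio:

    import torch

    disc = Discriminator(sample_rate=44100)
    x = torch.randn(1, 1, 44100)                 # B x 1 x T waveform

    fmaps = disc(x)                              # one list of feature maps per sub-discriminator
    print(len(fmaps))                            # 5 MPD + 3 MRD + 1 multi-scale CQT = 9 (MSD off by default)
    for i, fmap in enumerate(fmaps):
        print(f"disc {i}: {len(fmap)} maps, last {tuple(fmap[-1].shape)}")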
higgs_audio_tokenizer.py ADDED
@@ -0,0 +1,373 @@
1
+ # Based on code from: https://github.com/zhenye234/xcodec
2
+ # Licensed under MIT License
3
+ # Modifications by BosonAI
4
+
5
+ import math
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Union, Sequence
11
+ import numpy as np
12
+ from transformers import AutoModel
13
+ import torchaudio
14
+ import json
15
+ import librosa
16
+ from huggingface_hub import snapshot_download
17
+
18
+ from vector_quantize_pytorch import ResidualFSQ
19
+ from descriptaudiocodec.dac.model import dac as dac2
20
+ from quantization.vq import ResidualVectorQuantizer
21
+ from semantic_module import Encoder, Decoder
22
+
23
+ from transformers import HubertModel
24
+
25
+
26
+ # Weight-normalized layer helpers and weight initialization
27
+
28
+ def WNConv1d(*args, **kwargs):
29
+ """Applies weight normalization to a 1D Convolutional layer."""
30
+ return nn.utils.weight_norm(nn.Conv1d(*args, **kwargs))
31
+
32
+ def WNLinear(*args, **kwargs):
33
+ """Applies weight normalization to a Linear layer."""
34
+ return nn.utils.weight_norm(nn.Linear(*args, **kwargs))
35
+
36
+ def init_weights(m):
37
+ """
38
+ Applies truncated-normal initialization (std=0.02) to Conv, Linear,
39
+ and Embedding layers.
40
+ """
41
+ if isinstance(m, (nn.Conv1d, nn.Conv2d)):
42
+ # Truncated normal initialization for convolutional layers
43
+ nn.init.trunc_normal_(m.weight, std=0.02)
44
+ if m.bias is not None:
45
+ nn.init.constant_(m.bias, 0)
46
+ elif isinstance(m, nn.Linear):
47
+ # Also apply to linear layers for consistency
48
+ nn.init.trunc_normal_(m.weight, std=0.02)
49
+ if m.bias is not None:
50
+ nn.init.constant_(m.bias, 0)
51
+ elif isinstance(m, nn.Embedding):
52
+ # Initialize the codebook gently as well
53
+ nn.init.trunc_normal_(m.weight, std=0.02)
54
+
55
+
56
+ class EncodedResult:
57
+ def __init__(self, audio_codes):
58
+ self.audio_codes = audio_codes
59
+
60
+ class HiggsAudioFeatureExtractor(nn.Module):
61
+ def __init__(self, sampling_rate=16000):
62
+ super().__init__()
63
+ self.sampling_rate = sampling_rate
64
+
65
+ def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
66
+ audio_signal = torch.tensor(raw_audio)
67
+ audio_signal = audio_signal.unsqueeze(0)
68
+ if len(audio_signal.shape) < 3:
69
+ audio_signal = audio_signal.unsqueeze(0)
70
+ return {"input_values": audio_signal}
71
+
72
+
73
+ class HiggsAudioTokenizer(nn.Module):
74
+ def __init__(
75
+ self,
76
+ n_filters: int = 32,
77
+ D: int = 128,
78
+ target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
79
+ ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320
80
+ sample_rate: int = 16000,
81
+ bins: int = 1024,
82
+ n_q: int = 8,
83
+ codebook_dim: int = None,
84
+ normalize: bool = False,
85
+ causal: bool = False,
86
+ semantic_techer: str = "hubert_base_general",
87
+ last_layer_semantic: bool = True,
88
+ merge_mode: str = "concat",
89
+ downsample_mode: str = "step_down",
90
+ semantic_mode: str = "classic",
91
+ vq_scale: int = 1,
92
+ semantic_sample_rate: int = None,
93
+ device: str = "cuda",
94
+ ):
95
+ super().__init__()
96
+ self.hop_length = np.prod(ratios)
97
+ self.semantic_techer = semantic_techer
98
+
99
+ self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz
100
+
101
+ self.target_bandwidths = target_bandwidths
102
+ self.n_q = n_q
103
+ self.sample_rate = sample_rate
104
+ self.encoder = dac2.Encoder(64, ratios, D)
105
+
106
+ self.decoder_2 = dac2.Decoder(D, 1024, ratios)
107
+ self.last_layer_semantic = last_layer_semantic
108
+ self.device = device
109
+ if semantic_techer == "hubert_base":
110
+ self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
111
+ self.semantic_sample_rate = 16000
112
+ self.semantic_dim = 768
113
+ self.encoder_semantic_dim = 768
114
+
115
+ elif semantic_techer == "wavlm_base_plus":
116
+ self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
117
+ self.semantic_sample_rate = 16000
118
+ self.semantic_dim = 768
119
+ self.encoder_semantic_dim = 768
120
+
121
+ elif semantic_techer == "mHubert_base":
122
+ self.semantic_model = AutoModel.from_pretrained("utter-project/mHuBERT-147")
123
+ self.semantic_sample_rate = 16000
124
+ self.semantic_dim = 768
125
+ self.encoder_semantic_dim = 768
126
+
127
+ elif semantic_techer == "hubert_base_general":
128
+ self.semantic_model = HubertModel.from_pretrained("/home/ubuntu/.cache/huggingface/hub/models--bosonai--hubert_base/snapshots/b4b85f1652c16ad63fdc818221b215b79ff55934", trust_remote_code=False)
129
+ self.semantic_sample_rate = 16000
130
+ self.semantic_dim = 768
131
+ self.encoder_semantic_dim = 768
132
+
133
+ # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
134
+ if semantic_sample_rate is not None:
135
+ self.semantic_sample_rate = semantic_sample_rate
136
+
137
+ self.semantic_model.eval()
138
+
139
+ # freeze the semantic model so its parameters receive no gradients
140
+ for param in self.semantic_model.parameters():
141
+ param.requires_grad = False
142
+
143
+ self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)
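# With the defaults above (hop_length = prod([8, 5, 4, 2]) = 320, sample_rate = 16000, and a
# 16 kHz semantic teacher), this evaluates to int(320 / 1 / 320) = 1, i.e. the teacher features
# already arrive at the codec frame rate and need no extra downsampling.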
144
+
145
+ self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
146
+ self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
147
+ self.decoder_semantic = Decoder(
148
+ code_dim=self.encoder_semantic_dim, output_channels=self.semantic_dim, decode_channels=self.semantic_dim
149
+ )
150
+
151
+ # out_D=D+768
152
+ if isinstance(bins, int): # RVQ
153
+ self.quantizer = ResidualVectorQuantizer(
154
+ dimension=self.quantizer_dim, codebook_dim=codebook_dim, n_q=n_q, bins=bins
155
+ )
156
+ self.quantizer_type = "RVQ"
157
+ else: # RFSQ
158
+ self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
159
+ self.quantizer_type = "RFSQ"
160
+
161
+ # self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
162
+ # self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
163
+ # self.fc_post2 = nn.Linear(self.quantizer_dim, D)
164
+
165
+
166
+ self.fc_prior = WNLinear(D + self.encoder_semantic_dim, self.quantizer_dim)
167
+ self.fc_post1 = WNLinear(self.quantizer_dim, self.encoder_semantic_dim)
168
+ self.fc_post2 = WNLinear(self.quantizer_dim, D)
169
+
170
+
171
+ self.downsample_mode = downsample_mode
172
+ if downsample_mode == "avg":
173
+ self.semantic_pooling = nn.AvgPool1d(
174
+ kernel_size=self.semantic_downsample_factor, stride=self.semantic_downsample_factor
175
+ )
176
+
177
+ self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)
178
+
179
+ self.apply(init_weights)
180
+
181
+ @property
182
+ def tps(self):
183
+ return self.frame_rate
184
+
185
+ @property
186
+ def sampling_rate(self):
187
+ return self.sample_rate
188
+
189
+ @property
190
+ def num_codebooks(self):
191
+ return self.n_q
192
+
193
+ @property
194
+ def codebook_size(self):
195
+ return self.quantizer_dim
196
+
197
+ def get_last_layer(self):
198
+ return self.decoder.layers[-1].weight
199
+
200
+ def calculate_rec_loss(self, rec, target):
201
+ target = target / target.norm(dim=-1, keepdim=True)
202
+ rec = rec / rec.norm(dim=-1, keepdim=True)
203
+ rec_loss = (1 - (target * rec).sum(-1)).mean()
204
+
205
+ return rec_loss
206
+
207
+ @torch.no_grad()
208
+ def get_regress_target(self, x):
209
+ x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)
210
+
211
+ if (
212
+ self.semantic_techer == "hubert_base"
213
+ or self.semantic_techer == "hubert_base_general"
214
+ or self.semantic_techer == "wavlm_base_plus"
215
+ ):
216
+ x = x[:, 0, :]
217
+ x = F.pad(x, (160, 160))
218
+ target = self.semantic_model(x, output_hidden_states=True).hidden_states
219
+ target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)
220
+
221
+ # average for all layers
222
+ target = target.mean(1)
223
+ # target = target[9]
224
+ # if self.hop_length > 320:
225
+ # target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
226
+
227
+ elif self.semantic_techer == "w2v_bert2":
228
+ target = self.semantic_model(x)
229
+
230
+ elif self.semantic_techer.startswith("whisper"):
231
+ if self.last_layer_semantic:
232
+ target = self.semantic_model(x, avg_layers=False)
233
+ else:
234
+ target = self.semantic_model(x, avg_layers=True)
235
+
236
+ elif self.semantic_techer.startswith("mert_music"):
237
+ if self.last_layer_semantic:
238
+ target = self.semantic_model(x, avg_layers=False)
239
+ else:
240
+ target = self.semantic_model(x, avg_layers=True)
241
+
242
+ elif self.semantic_techer.startswith("qwen_audio_omni"):
243
+ target = self.semantic_model(x)
244
+
245
+ if self.downsample_mode == "step_down":
246
+ if self.semantic_downsample_factor > 1:
247
+ target = target[:, :: self.semantic_downsample_factor, :]
248
+
249
+ elif self.downsample_mode == "avg":
250
+ target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
251
+ return target
252
+
253
+ def forward(self, x: torch.Tensor, bw: int):
254
+ e_semantic_input = self.get_regress_target(x).detach()
255
+
256
+ e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
257
+ e_acoustic = self.encoder(x)
258
+
259
+ e = torch.cat([e_acoustic, e_semantic], dim=1)
260
+
261
+ e = self.fc_prior(e.transpose(1, 2))
262
+
263
+ if self.quantizer_type == "RVQ":
264
+ e = e.transpose(1, 2)
265
+ quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
266
+ quantized = quantized.transpose(1, 2)
267
+ else:
268
+ quantized, codes = self.quantizer(e)
269
+ commit_loss = torch.tensor(0.0)
270
+
271
+ quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
272
+ quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
273
+
274
+ o = self.decoder_2(quantized_acoustic)
275
+
276
+ o_semantic = self.decoder_semantic(quantized_semantic)
277
+ semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)
278
+
279
+ return o, commit_loss, semantic_recon_loss, None
280
+
281
+ def encode(self, audio_path_or_wv, sr=None, loudness_normalize=False, loudness_threshold=-23.0):
282
+ if isinstance(audio_path_or_wv, str):
283
+ wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
284
+ else:
285
+ wv = audio_path_or_wv
286
+ assert sr is not None
287
+ if loudness_normalize:
288
+ import pyloudnorm as pyln
289
+
290
+ meter = pyln.Meter(sr)
291
+ l = meter.integrated_loudness(wv)
292
+ wv = pyln.normalize.loudness(wv, l, loudness_threshold)
293
+ if sr != self.sampling_rate:
294
+ wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
295
+ if self.audio_tokenizer_feature_extractor is not None:
296
+ inputs = self.audio_tokenizer_feature_extractor(
297
+ raw_audio=wv, sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate, return_tensors="pt"
298
+ )
299
+ input_values = inputs["input_values"].to(self.device)
300
+ else:
301
+ input_values = torch.from_numpy(wv).float().unsqueeze(0)
302
+ with torch.no_grad():
303
+ encoder_outputs = self._xcodec_encode(input_values)
304
+ vq_code = encoder_outputs.audio_codes[0]
305
+ return vq_code
306
+
307
+
308
+
309
+ def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
310
+ bw = target_bw
311
+
312
+ e_semantic_input = self.get_regress_target(x).detach()
313
+
314
+ e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
315
+ e_acoustic = self.encoder(x)
316
+
317
+ if e_acoustic.shape[2] != e_semantic.shape[2]:
318
+ pad_size = 160 * self.semantic_downsample_factor
319
+ e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))
320
+
321
+ if e_acoustic.shape[2] != e_semantic.shape[2]:
322
+ if e_acoustic.shape[2] > e_semantic.shape[2]:
323
+ e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
324
+ else:
325
+ e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]
326
+
327
+ e = torch.cat([e_acoustic, e_semantic], dim=1)
328
+
329
+ e = self.fc_prior(e.transpose(1, 2))
330
+
331
+ if self.quantizer_type == "RVQ":
332
+ e = e.transpose(1, 2)
333
+ quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
334
+ codes = codes.permute(1, 0, 2)
335
+ else:
336
+ quantized, codes = self.quantizer(e)
337
+ codes = codes.permute(0, 2, 1)
338
+
339
+ # return codes
340
+ return EncodedResult(codes)
341
+
342
+ def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
343
+ if self.quantizer_type == "RVQ":
344
+ vq_code = vq_code.permute(1, 0, 2)
345
+ quantized = self.quantizer.decode(vq_code)
346
+ quantized = quantized.transpose(1, 2)
347
+ else:
348
+ vq_code = vq_code.permute(0, 2, 1)
349
+ quantized = self.quantizer.get_output_from_indices(vq_code)
350
+ quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
351
+
352
+ o = self.decoder_2(quantized_acoustic)
353
+ return o.cpu().numpy()
354
+
355
+
356
+ def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
357
+ is_local = os.path.exists(tokenizer_name_or_path)
358
+ if not is_local:
359
+ tokenizer_path = snapshot_download(tokenizer_name_or_path)
360
+ else:
361
+ tokenizer_path = tokenizer_name_or_path
362
+ config_path = os.path.join(tokenizer_path, "config.json")
363
+ model_path = os.path.join(tokenizer_path, "model.pth")
364
+ config = json.load(open(config_path))
365
+ model = HiggsAudioTokenizer(
366
+ **config,
367
+ device=device,
368
+ )
369
+ parameter_dict = torch.load(model_path, map_location=device, weights_only=False)
370
+ model.load_state_dict(parameter_dict, strict=False)
371
+ model.to(device)
372
+ model.eval()
373
+ return model
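A sketch of the intended encode/decode round trip through the loader above. The checkpoint identifier and file names are placeholders (any directory or Hub repo containing a matching config.json and model.pth should work), and soundfile is assumed only for writing the result:

    import soundfile as sf

    tokenizer = load_higgs_audio_tokenizer("path/or/hub-id/of/tokenizer", device="cuda")

    # Discrete codes, shape (n_q, T_frames); ~50 frames/sec with the default 16 kHz config
    codes = tokenizer.encode("example.wav")
    print(codes.shape)

    # decode expects a batch dimension and returns a numpy array of shape (B, 1, T_samples)
    wav = tokenizer.decode(codes.unsqueeze(0))
    sf.write("reconstruction.wav", wav[0, 0], tokenizer.sampling_rate)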
loss.py ADDED
@@ -0,0 +1,368 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from audiotools import AudioSignal
4
+ from audiotools import STFTParams
5
+ from torch import nn
6
+ import typing
7
+ from typing import List
8
+
9
+ class L1Loss(nn.L1Loss):
10
+ """L1 Loss between AudioSignals. Defaults
11
+ to comparing ``audio_data``, but any
12
+ attribute of an AudioSignal can be used.
13
+
14
+ Parameters
15
+ ----------
16
+ attribute : str, optional
17
+ Attribute of signal to compare, defaults to ``audio_data``.
18
+ weight : float, optional
19
+ Weight of this loss, defaults to 1.0.
20
+
21
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
22
+ """
23
+
24
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
25
+ self.attribute = attribute
26
+ self.weight = weight
27
+ super().__init__(**kwargs)
28
+
29
+ def forward(self, x: AudioSignal, y: AudioSignal):
30
+ """
31
+ Parameters
32
+ ----------
33
+ x : AudioSignal
34
+ Estimate AudioSignal
35
+ y : AudioSignal
36
+ Reference AudioSignal
37
+
38
+ Returns
39
+ -------
40
+ torch.Tensor
41
+ L1 loss between AudioSignal attributes.
42
+ """
43
+ if isinstance(x, AudioSignal):
44
+ x = getattr(x, self.attribute)
45
+ y = getattr(y, self.attribute)
46
+ return super().forward(x, y)
47
+
48
+
49
+ class SISDRLoss(nn.Module):
50
+ """
51
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
52
+ of estimated and reference audio signals or aligned features.
53
+
54
+ Parameters
55
+ ----------
56
+ scaling : int, optional
57
+ Whether to use scale-invariant (True) or
58
+ signal-to-noise ratio (False), by default True
59
+ reduction : str, optional
60
+ How to reduce across the batch (either 'mean',
61
+ 'sum', or none).], by default ' mean'
62
+ zero_mean : int, optional
63
+ Zero mean the references and estimates before
64
+ computing the loss, by default True
65
+ clip_min : int, optional
66
+ The minimum possible loss value. Helps network
67
+ to not focus on making already good examples better, by default None
68
+ weight : float, optional
69
+ Weight of this loss, defaults to 1.0.
70
+
71
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ scaling: int = True,
77
+ reduction: str = "mean",
78
+ zero_mean: int = True,
79
+ clip_min: int = None,
80
+ weight: float = 1.0,
81
+ ):
82
+ self.scaling = scaling
83
+ self.reduction = reduction
84
+ self.zero_mean = zero_mean
85
+ self.clip_min = clip_min
86
+ self.weight = weight
87
+ super().__init__()
88
+
89
+ def forward(self, x: AudioSignal, y: AudioSignal):
90
+ eps = 1e-8
91
+ # nb, nc, nt
92
+ if isinstance(x, AudioSignal):
93
+ references = x.audio_data
94
+ estimates = y.audio_data
95
+ else:
96
+ references = x
97
+ estimates = y
98
+
99
+ nb = references.shape[0]
100
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
101
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
102
+
103
+ # samples now on axis 1
104
+ if self.zero_mean:
105
+ mean_reference = references.mean(dim=1, keepdim=True)
106
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
107
+ else:
108
+ mean_reference = 0
109
+ mean_estimate = 0
110
+
111
+ _references = references - mean_reference
112
+ _estimates = estimates - mean_estimate
113
+
114
+ references_projection = (_references**2).sum(dim=-2) + eps
115
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
116
+
117
+ scale = (
118
+ (references_on_estimates / references_projection).unsqueeze(1)
119
+ if self.scaling
120
+ else 1
121
+ )
122
+
123
+ e_true = scale * _references
124
+ e_res = _estimates - e_true
125
+
126
+ signal = (e_true**2).sum(dim=1)
127
+ noise = (e_res**2).sum(dim=1)
128
+ sdr = -10 * torch.log10(signal / noise + eps)
129
+
130
+ if self.clip_min is not None:
131
+ sdr = torch.clamp(sdr, min=self.clip_min)
132
+
133
+ if self.reduction == "mean":
134
+ sdr = sdr.mean()
135
+ elif self.reduction == "sum":
136
+ sdr = sdr.sum()
137
+ return sdr
138
+
139
+
140
+ class MultiScaleSTFTLoss(nn.Module):
141
+ """Computes the multi-scale STFT loss from [1].
142
+
143
+ Parameters
144
+ ----------
145
+ window_lengths : List[int], optional
146
+ Length of each window of each STFT, by default [2048, 512]
147
+ loss_fn : typing.Callable, optional
148
+ How to compare each loss, by default nn.L1Loss()
149
+ clamp_eps : float, optional
150
+ Lower bound applied to the magnitude before taking the log, by default 1e-5
151
+ mag_weight : float, optional
152
+ Weight of raw magnitude portion of loss, by default 1.0
153
+ log_weight : float, optional
154
+ Weight of log magnitude portion of loss, by default 1.0
155
+ pow : float, optional
156
+ Power to raise magnitude to before taking log, by default 2.0
157
+ weight : float, optional
158
+ Weight of this loss, by default 1.0
159
+ match_stride : bool, optional
160
+ Whether to match the stride of convolutional layers, by default False
161
+
162
+ References
163
+ ----------
164
+
165
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
166
+ "DDSP: Differentiable Digital Signal Processing."
167
+ International Conference on Learning Representations. 2019.
168
+
169
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
170
+ """
171
+
172
+ def __init__(
173
+ self,
174
+ window_lengths: List[int] = [2048, 512],
175
+ loss_fn: typing.Callable = nn.L1Loss(),
176
+ clamp_eps: float = 1e-5,
177
+ mag_weight: float = 1.0,
178
+ log_weight: float = 1.0,
179
+ pow: float = 2.0,
180
+ weight: float = 1.0,
181
+ match_stride: bool = False,
182
+ window_type: str = None,
183
+ ):
184
+ super().__init__()
185
+ self.stft_params = [
186
+ STFTParams(
187
+ window_length=w,
188
+ hop_length=w // 4,
189
+ match_stride=match_stride,
190
+ window_type=window_type,
191
+ )
192
+ for w in window_lengths
193
+ ]
194
+ self.loss_fn = loss_fn
195
+ self.log_weight = log_weight
196
+ self.mag_weight = mag_weight
197
+ self.clamp_eps = clamp_eps
198
+ self.weight = weight
199
+ self.pow = pow
200
+
201
+ def forward(self, x: AudioSignal, y: AudioSignal):
202
+ """Computes multi-scale STFT between an estimate and a reference
203
+ signal.
204
+
205
+ Parameters
206
+ ----------
207
+ x : AudioSignal
208
+ Estimate signal
209
+ y : AudioSignal
210
+ Reference signal
211
+
212
+ Returns
213
+ -------
214
+ torch.Tensor
215
+ Multi-scale STFT loss.
216
+ """
217
+ loss = 0.0
218
+ for s in self.stft_params:
219
+ x.stft(s.window_length, s.hop_length, s.window_type)
220
+ y.stft(s.window_length, s.hop_length, s.window_type)
221
+ loss += self.log_weight * self.loss_fn(
222
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
223
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
224
+ )
225
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
226
+ return loss
227
+
228
+
229
+ class MelSpectrogramLoss(nn.Module):
230
+ """Compute distance between mel spectrograms. Can be used
231
+ in a multi-scale way.
232
+
233
+ Parameters
234
+ ----------
235
+ n_mels : List[int]
236
+ Number of mels per STFT, by default [150, 80]
237
+ window_lengths : List[int], optional
238
+ Length of each window of each STFT, by default [2048, 512]
239
+ loss_fn : typing.Callable, optional
240
+ How to compare each loss, by default nn.L1Loss()
241
+ clamp_eps : float, optional
242
+ Lower bound applied to the magnitude before taking the log, by default 1e-5
243
+ mag_weight : float, optional
244
+ Weight of raw magnitude portion of loss, by default 1.0
245
+ log_weight : float, optional
246
+ Weight of log magnitude portion of loss, by default 1.0
247
+ pow : float, optional
248
+ Power to raise magnitude to before taking log, by default 2.0
249
+ weight : float, optional
250
+ Weight of this loss, by default 1.0
251
+ match_stride : bool, optional
252
+ Whether to match the stride of convolutional layers, by default False
253
+
254
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ n_mels: List[int] = [150, 80],
260
+ window_lengths: List[int] = [2048, 512],
261
+ loss_fn: typing.Callable = nn.L1Loss(),
262
+ clamp_eps: float = 1e-5,
263
+ mag_weight: float = 1.0,
264
+ log_weight: float = 1.0,
265
+ pow: float = 2.0,
266
+ weight: float = 1.0,
267
+ match_stride: bool = False,
268
+ mel_fmin: List[float] = [0.0, 0.0],
269
+ mel_fmax: List[float] = [None, None],
270
+ window_type: str = None,
271
+ ):
272
+ super().__init__()
273
+ self.stft_params = [
274
+ STFTParams(
275
+ window_length=w,
276
+ hop_length=w // 4,
277
+ match_stride=match_stride,
278
+ window_type=window_type,
279
+ )
280
+ for w in window_lengths
281
+ ]
282
+ self.n_mels = n_mels
283
+ self.loss_fn = loss_fn
284
+ self.clamp_eps = clamp_eps
285
+ self.log_weight = log_weight
286
+ self.mag_weight = mag_weight
287
+ self.weight = weight
288
+ self.mel_fmin = mel_fmin
289
+ self.mel_fmax = mel_fmax
290
+ self.pow = pow
291
+
292
+ def forward(self, x: AudioSignal, y: AudioSignal):
293
+ """Computes mel loss between an estimate and a reference
294
+ signal.
295
+
296
+ Parameters
297
+ ----------
298
+ x : AudioSignal
299
+ Estimate signal
300
+ y : AudioSignal
301
+ Reference signal
302
+
303
+ Returns
304
+ -------
305
+ torch.Tensor
306
+ Mel loss.
307
+ """
308
+ loss = 0.0
309
+ for n_mels, fmin, fmax, s in zip(
310
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
311
+ ):
312
+ kwargs = {
313
+ "window_length": s.window_length,
314
+ "hop_length": s.hop_length,
315
+ "window_type": s.window_type,
316
+ }
317
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
318
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
319
+
320
+ loss += self.log_weight * self.loss_fn(
321
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
322
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
323
+ )
324
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
325
+ return loss
326
+
327
+
328
+ class GANLoss(nn.Module):
329
+ """
330
+ Computes a discriminator loss, given a discriminator on
331
+ generated waveforms/spectrograms compared to ground truth
332
+ waveforms/spectrograms. Computes the loss for both the
333
+ discriminator and the generator in separate functions.
334
+ """
335
+
336
+ def __init__(self, discriminator):
337
+ super().__init__()
338
+ self.discriminator = discriminator
339
+
340
+ def forward(self, fake, real):
341
+ d_fake = self.discriminator(fake.audio_data)
342
+ d_real = self.discriminator(real.audio_data)
343
+ return d_fake, d_real
344
+
345
+ def discriminator_loss(self, fake, real):
346
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
347
+
348
+ loss_d = 0
349
+ for x_fake, x_real in zip(d_fake, d_real):
350
+ loss_d += torch.mean(x_fake[-1] ** 2)
351
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
352
+ return loss_d
353
+
354
+ def generator_loss(self, fake, real):
355
+ d_fake, d_real = self.forward(fake, real)
356
+
357
+ loss_g = 0
358
+ for x_fake in d_fake:
359
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
360
+
361
+ loss_feature = 0
362
+
363
+ for i in range(len(d_fake)):
364
+ for j in range(len(d_fake[i]) - 1):
365
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
366
+ return loss_g, loss_feature
367
+
368
+
outputs/logs/250801-104649/events.out.tfevents.1754045209.192-222-50-191.575849.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adf08230fc10f50d2a7fcb4c2deaf1b2cbb45116b6d44f1cce5eb365471f516b
3
+ size 657
outputs/logs/250801-104824/events.out.tfevents.1754045304.192-222-50-191.577752.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85e63a92198fbc1d513003ee532294feaedf87505a51e15ec2fc520900ec53c4
3
+ size 657
outputs/logs/250801-104944/events.out.tfevents.1754045384.192-222-50-191.579650.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4e7f045c676534be11f5ad39b553611caa786d3e1559cbf100ca84133d54a77
3
+ size 88
outputs/logs/250801-105034/events.out.tfevents.1754045434.192-222-50-191.581483.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5c97f5d3b5b6cc60223a60dceb2b06f6a847e682935f8d45086892f176bd98d
3
+ size 657
outputs/logs/250801-105133/events.out.tfevents.1754045493.192-222-50-191.583409.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7bae6adc53141720bb44a2fda69e2f97bce51b98365458ec53ce59c66bf255f
3
+ size 5751664
outputs/logs/250801-134657/events.out.tfevents.1754056017.192-222-50-191.688744.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c021361932098e4ba601f59134b8ea6c0192a31d103eb6374070e394bb740060
3
+ size 61388
outputs/logs/250801-135301/events.out.tfevents.1754056381.192-222-50-191.693590.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61a297a283372869b74886f9fa2a17d11bd77fe284275aae505fbaa28957cda8
3
+ size 88
outputs/logs/250801-135344/events.out.tfevents.1754056424.192-222-50-191.695388.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0954cc521ca9b594ace0fb485ccf8556621e017c08da9d525e4071c94489c3ec
3
+ size 657
outputs/logs/250801-135510/events.out.tfevents.1754056510.192-222-50-191.697490.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dcf0296a6058bce4cfc81563b9ffaac98d0e171d3f69e15008f179aeb4e5d8e
3
+ size 3419391
outputs/logs/250801-202235/events.out.tfevents.1754079755.192-222-50-191.6026.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:163634220bcd180cfe60f9d73024faeeeaed6c2849ea5d44909e6666e4a4ac54
3
+ size 88
outputs/logs/250801-202320/events.out.tfevents.1754079800.192-222-50-191.6708.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7088dc97ef965f0fb3bfff3b1f28e1b63644f666b6464a8524235f5bdd68a24
3
+ size 7234042
outputs/logs/250802-065733/events.out.tfevents.1754117853.192-222-50-191.86944.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b13f240f199d42345f13428575a813d039fdd4e8899a915b7b55741b504ea46
3
+ size 208834
outputs/logs/250802-072035/events.out.tfevents.1754119235.192-222-50-191.100690.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93fc6699e79f38e10ada3f7431617725b9b35daee4cda1978fb73800f7304113
3
+ size 3373
outputs_24/logs/250730-112649/events.out.tfevents.1753874809.192-222-50-191.3556345.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34843a6a8d0849900de8943c4bbf4dcf77424e605eb1e4a721b8cbf894f1fc73
3
+ size 88
outputs_24/logs/250730-112910/events.out.tfevents.1753874950.192-222-50-191.3557426.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d9cca204af08430087351d78de76e17d16ad4ed8e2e31540b0857406c77bb86
3
+ size 1990
outputs_24/logs/250730-113135/events.out.tfevents.1753875095.192-222-50-191.3558918.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0f7180018cd52196492be018bc160ce1222c437e27d4019bd87d5207e9f1fe5
3
+ size 63081
outputs_24/logs/250730-114727/events.out.tfevents.1753876047.192-222-50-191.3567432.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:395a91280aa73bfaf31de37cb66f76ca9916e25d59fa84568e2124ad19162e7a
3
+ size 3808
outputs_24/logs/250730-115006/events.out.tfevents.1753876206.192-222-50-191.3569242.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b47cceec7b3b651af1fa77129099bd8f0fcebfb60ba0950b1a0efeba2ea28a2f
3
+ size 6744012
outputs_24/logs/250730-151325/events.out.tfevents.1753888405.192-222-50-191.3660307.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69269e593e8fb146793020448eedf00cda64a27f3b3a3cebf47dbc5c7df5e8dc
3
+ size 5976
outputs_24/logs/250730-152054/events.out.tfevents.1753888854.192-222-50-191.3663830.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f87195745b3babee5ab3afbc4b46d85decbf2c39f31376891b3d001f1e1db52
3
+ size 88
outputs_24/logs/250730-152132/events.out.tfevents.1753888892.192-222-50-191.3664702.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7e10c22d4a06fb02d74592b2ac393196affce38724e0c2fa3117dfec6931acf
3
+ size 88
outputs_24/logs/250730-152218/events.out.tfevents.1753888938.192-222-50-191.3665630.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c9e1081361a51f466892f52660c7126dd4e281ad1e1a86b01ecf914a03318ed
3
+ size 657
outputs_24/logs/250730-152329/events.out.tfevents.1753889009.192-222-50-191.3666743.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbbfa9f88f5e7790426f8bee52f171526f4d7fb1ab5efb51e25c77581a3afee2
3
+ size 657
outputs_24/logs/250730-152554/events.out.tfevents.1753889154.192-222-50-191.3668339.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bd5fbc186ff41408a06c8d8ecb91ea3429ab1971c864edb0411ae2391af8965
3
+ size 88
outputs_24/logs/250730-152702/events.out.tfevents.1753889222.192-222-50-191.3669391.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df5419dd06bca60309a60b28ee070c8fc816845de941a0b311ebfbb6777dccac
3
+ size 88
outputs_24/logs/250730-152902/events.out.tfevents.1753889342.192-222-50-191.3671654.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f2377fd533270911cdd47226a5b11222bb0d037b3a47eaac9c66b3ffc605d03
3
+ size 1526378
outputs_24/logs/250730-161025/events.out.tfevents.1753891825.192-222-50-191.3698156.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3427b4010f32544ebc7c050f6ed85ce380c15bddcc039e3fffa295b4a2b4813
3
+ size 1528786
outputs_24/logs/250730-165034/events.out.tfevents.1753894234.192-222-50-191.3717308.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:242a5126eb7e28c25d6ccf4377c745c970065110502df9dbf74e3392eff098c8
3
+ size 4794
outputs_24/logs/250730-165327/events.out.tfevents.1753894407.192-222-50-191.3719515.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f02d9f09b42a51e0879c5de3e059703520e0aa2c0a1fc329e2878ecd0deb1c23
3
+ size 657