Respair commited on
Commit
ce00803
·
verified ·
1 Parent(s): 239cab4

Update higgs_audio_tokenizer.py

Browse files
Files changed (1) hide show
  1. higgs_audio_tokenizer.py +31 -47
higgs_audio_tokenizer.py CHANGED
@@ -23,33 +23,26 @@ from semantic_module import Encoder, Decoder
23
  from transformers import HubertModel
24
 
25
 
26
- # At the top of higgs_audio_tokenizer.py, after the imports
27
 
28
  def WNConv1d(*args, **kwargs):
29
- """Applies weight normalization to a 1D Convolutional layer."""
30
  return nn.utils.weight_norm(nn.Conv1d(*args, **kwargs))
31
 
32
  def WNLinear(*args, **kwargs):
33
- """Applies weight normalization to a Linear layer."""
34
  return nn.utils.weight_norm(nn.Linear(*args, **kwargs))
35
 
36
  def init_weights(m):
37
- """
38
- Applies Xavier (Glorot) uniform initialization to Conv and Linear layers.
39
- This is a robust, "classic" initialization scheme.
40
- """
41
  if isinstance(m, (nn.Conv1d, nn.Conv2d)):
42
- # Truncated normal initialization for convolutional layers
43
  nn.init.trunc_normal_(m.weight, std=0.02)
44
  if m.bias is not None:
45
  nn.init.constant_(m.bias, 0)
46
  elif isinstance(m, nn.Linear):
47
- # Also apply to linear layers for consistency
48
  nn.init.trunc_normal_(m.weight, std=0.02)
49
  if m.bias is not None:
50
  nn.init.constant_(m.bias, 0)
51
  elif isinstance(m, nn.Embedding):
52
- # Initialize the codebook gently as well
53
  nn.init.trunc_normal_(m.weight, std=0.02)
54
 
55
 
@@ -76,7 +69,7 @@ class HiggsAudioTokenizer(nn.Module):
76
  n_filters: int = 32,
77
  D: int = 128,
78
  target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
79
- ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320
80
  sample_rate: int = 16000,
81
  bins: int = 1024,
82
  n_q: int = 8,
@@ -96,7 +89,7 @@ class HiggsAudioTokenizer(nn.Module):
96
  self.hop_length = np.prod(ratios)
97
  self.semantic_techer = semantic_techer
98
 
99
- self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz
100
 
101
  self.target_bandwidths = target_bandwidths
102
  self.n_q = n_q
@@ -106,6 +99,8 @@ class HiggsAudioTokenizer(nn.Module):
106
  self.decoder_2 = dac2.Decoder(D, 1024, ratios)
107
  self.last_layer_semantic = last_layer_semantic
108
  self.device = device
 
 
109
  if semantic_techer == "hubert_base":
110
  self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
111
  self.semantic_sample_rate = 16000
@@ -125,18 +120,16 @@ class HiggsAudioTokenizer(nn.Module):
125
  self.encoder_semantic_dim = 768
126
 
127
  elif semantic_techer == "hubert_base_general":
128
- self.semantic_model = HubertModel.from_pretrained("/home/ubuntu/.cache/huggingface/hub/models--bosonai--hubert_base/snapshots/b4b85f1652c16ad63fdc818221b215b79ff55934", trust_remote_code=False)
129
  self.semantic_sample_rate = 16000
130
  self.semantic_dim = 768
131
  self.encoder_semantic_dim = 768
132
 
133
- # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
134
  if semantic_sample_rate is not None:
135
  self.semantic_sample_rate = semantic_sample_rate
136
 
137
  self.semantic_model.eval()
138
 
139
- # make the semantic model parameters do not need gradient
140
  for param in self.semantic_model.parameters():
141
  param.requires_grad = False
142
 
@@ -148,20 +141,15 @@ class HiggsAudioTokenizer(nn.Module):
148
  code_dim=self.encoder_semantic_dim, output_channels=self.semantic_dim, decode_channels=self.semantic_dim
149
  )
150
 
151
- # out_D=D+768
152
- if isinstance(bins, int): # RVQ
153
  self.quantizer = ResidualVectorQuantizer(
154
  dimension=self.quantizer_dim, codebook_dim=codebook_dim, n_q=n_q, bins=bins
155
  )
156
  self.quantizer_type = "RVQ"
157
- else: # RFSQ
158
  self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
159
  self.quantizer_type = "RFSQ"
160
 
161
- # self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
162
- # self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
163
- # self.fc_post2 = nn.Linear(self.quantizer_dim, D)
164
-
165
 
166
  self.fc_prior = WNLinear(D + self.encoder_semantic_dim, self.quantizer_dim)
167
  self.fc_post1 = WNLinear(self.quantizer_dim, self.encoder_semantic_dim)
@@ -212,17 +200,14 @@ class HiggsAudioTokenizer(nn.Module):
212
  self.semantic_techer == "hubert_base"
213
  or self.semantic_techer == "hubert_base_general"
214
  or self.semantic_techer == "wavlm_base_plus"
 
215
  ):
216
  x = x[:, 0, :]
217
  x = F.pad(x, (160, 160))
218
  target = self.semantic_model(x, output_hidden_states=True).hidden_states
219
- target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)
220
 
221
- # average for all layers
222
  target = target.mean(1)
223
- # target = target[9]
224
- # if self.hop_length > 320:
225
- # target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
226
 
227
  elif self.semantic_techer == "w2v_bert2":
228
  target = self.semantic_model(x)
@@ -278,7 +263,7 @@ class HiggsAudioTokenizer(nn.Module):
278
 
279
  return o, commit_loss, semantic_recon_loss, None
280
 
281
- def encode(self, audio_path_or_wv, sr=None, loudness_normalize=False, loudness_threshold=-23.0):
282
  if isinstance(audio_path_or_wv, str):
283
  wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
284
  else:
@@ -336,7 +321,6 @@ class HiggsAudioTokenizer(nn.Module):
336
  quantized, codes = self.quantizer(e)
337
  codes = codes.permute(0, 2, 1)
338
 
339
- # return codes
340
  return EncodedResult(codes)
341
 
342
  def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
@@ -353,21 +337,21 @@ class HiggsAudioTokenizer(nn.Module):
353
  return o.cpu().numpy()
354
 
355
 
356
- def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
357
- is_local = os.path.exists(tokenizer_name_or_path)
358
- if not is_local:
359
- tokenizer_path = snapshot_download(tokenizer_name_or_path)
360
- else:
361
- tokenizer_path = tokenizer_name_or_path
362
- config_path = os.path.join(tokenizer_path, "config.json")
363
- model_path = os.path.join(tokenizer_path, "model.pth")
364
- config = json.load(open(config_path))
365
- model = HiggsAudioTokenizer(
366
- **config,
367
- device=device,
368
- )
369
- parameter_dict = torch.load(model_path, map_location=device, weights_only=False)
370
- model.load_state_dict(parameter_dict, strict=False)
371
- model.to(device)
372
- model.eval()
373
- return model
 
23
  from transformers import HubertModel
24
 
25
 
 
26
 
27
  def WNConv1d(*args, **kwargs):
28
+
29
  return nn.utils.weight_norm(nn.Conv1d(*args, **kwargs))
30
 
31
  def WNLinear(*args, **kwargs):
32
+
33
  return nn.utils.weight_norm(nn.Linear(*args, **kwargs))
34
 
35
  def init_weights(m):
36
+
 
 
 
37
  if isinstance(m, (nn.Conv1d, nn.Conv2d)):
 
38
  nn.init.trunc_normal_(m.weight, std=0.02)
39
  if m.bias is not None:
40
  nn.init.constant_(m.bias, 0)
41
  elif isinstance(m, nn.Linear):
 
42
  nn.init.trunc_normal_(m.weight, std=0.02)
43
  if m.bias is not None:
44
  nn.init.constant_(m.bias, 0)
45
  elif isinstance(m, nn.Embedding):
 
46
  nn.init.trunc_normal_(m.weight, std=0.02)
47
 
48
 
 
69
  n_filters: int = 32,
70
  D: int = 128,
71
  target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
72
+ ratios: Sequence[int] = [8, 5, 4, 2],
73
  sample_rate: int = 16000,
74
  bins: int = 1024,
75
  n_q: int = 8,
 
89
  self.hop_length = np.prod(ratios)
90
  self.semantic_techer = semantic_techer
91
 
92
+ self.frame_rate = math.ceil(sample_rate / np.prod(ratios))
93
 
94
  self.target_bandwidths = target_bandwidths
95
  self.n_q = n_q
 
99
  self.decoder_2 = dac2.Decoder(D, 1024, ratios)
100
  self.last_layer_semantic = last_layer_semantic
101
  self.device = device
102
+
103
+
104
  if semantic_techer == "hubert_base":
105
  self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
106
  self.semantic_sample_rate = 16000
 
120
  self.encoder_semantic_dim = 768
121
 
122
  elif semantic_techer == "hubert_base_general":
123
+ self.semantic_model = HubertModel.from_pretrained("bosonai/hubert_base", trust_remote_code=False)
124
  self.semantic_sample_rate = 16000
125
  self.semantic_dim = 768
126
  self.encoder_semantic_dim = 768
127
 
 
128
  if semantic_sample_rate is not None:
129
  self.semantic_sample_rate = semantic_sample_rate
130
 
131
  self.semantic_model.eval()
132
 
 
133
  for param in self.semantic_model.parameters():
134
  param.requires_grad = False
135
 
 
141
  code_dim=self.encoder_semantic_dim, output_channels=self.semantic_dim, decode_channels=self.semantic_dim
142
  )
143
 
144
+ if isinstance(bins, int):
 
145
  self.quantizer = ResidualVectorQuantizer(
146
  dimension=self.quantizer_dim, codebook_dim=codebook_dim, n_q=n_q, bins=bins
147
  )
148
  self.quantizer_type = "RVQ"
149
+ else:
150
  self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
151
  self.quantizer_type = "RFSQ"
152
 
 
 
 
 
153
 
154
  self.fc_prior = WNLinear(D + self.encoder_semantic_dim, self.quantizer_dim)
155
  self.fc_post1 = WNLinear(self.quantizer_dim, self.encoder_semantic_dim)
 
200
  self.semantic_techer == "hubert_base"
201
  or self.semantic_techer == "hubert_base_general"
202
  or self.semantic_techer == "wavlm_base_plus"
203
+ or self.semantic_techer == "mHubert_base"
204
  ):
205
  x = x[:, 0, :]
206
  x = F.pad(x, (160, 160))
207
  target = self.semantic_model(x, output_hidden_states=True).hidden_states
208
+ target = torch.stack(target, dim=1)
209
 
 
210
  target = target.mean(1)
 
 
 
211
 
212
  elif self.semantic_techer == "w2v_bert2":
213
  target = self.semantic_model(x)
 
263
 
264
  return o, commit_loss, semantic_recon_loss, None
265
 
266
+ def encode(self, audio_path_or_wv, sr=44100, loudness_normalize=False, loudness_threshold=-23.0):
267
  if isinstance(audio_path_or_wv, str):
268
  wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
269
  else:
 
321
  quantized, codes = self.quantizer(e)
322
  codes = codes.permute(0, 2, 1)
323
 
 
324
  return EncodedResult(codes)
325
 
326
  def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
 
337
  return o.cpu().numpy()
338
 
339
 
340
+ # def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"): # not used here due to changes
341
+ # is_local = os.path.exists(tokenizer_name_or_path)
342
+ # if not is_local:
343
+ # tokenizer_path = snapshot_download(tokenizer_name_or_path)
344
+ # else:
345
+ # tokenizer_path = tokenizer_name_or_path
346
+ # config_path = os.path.join(tokenizer_path, "config.json")
347
+ # model_path = os.path.join(tokenizer_path, "model.pth")
348
+ # config = json.load(open(config_path))
349
+ # model = HiggsAudioTokenizer(
350
+ # **config,
351
+ # device=device,
352
+ # )
353
+ # parameter_dict = torch.load(model_path, map_location=device, weights_only=False)
354
+ # model.load_state_dict(parameter_dict, strict=False)
355
+ # model.to(device)
356
+ # model.eval()
357
+ # return model