yangwang825 committed
Commit c976bf2 · verified · 1 parent: 726c4db

Upload model

Files changed (3)
  1. config.json +5 -3
  2. model.safetensors +3 -0
  3. modeling_wav2vec2_spkreg.py +764 -0
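Together, the three files register a custom wav2vec2-based speaker-recognition ("SpkReg") model as remote code: the config.json diff below adds `auto_map` entries so the custom classes resolve through the Auto classes with `trust_remote_code=True`. A minimal loading sketch is given here (the repository id is a placeholder, since the commit view does not show it):

```python
# Minimal loading sketch. "yangwang825/<repo-id>" is a placeholder; the actual Hub id is not
# shown in this commit view. trust_remote_code=True lets AutoConfig/AutoModel import the
# configuration_wav2vec2_spkreg / modeling_wav2vec2_spkreg files referenced in auto_map.
from transformers import AutoConfig, AutoModel

repo_id = "yangwang825/<repo-id>"  # placeholder
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)   # Wav2Vec2SpkRegConfig
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)     # Wav2Vec2SpkRegModel
```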
config.json CHANGED
@@ -1,4 +1,5 @@
 {
+  "_name_or_path": "facebook/wav2vec2-base",
   "activation_dropout": 0.0,
   "adapter_attn_dim": null,
   "adapter_kernel_size": 3,
@@ -6,11 +7,12 @@
   "add_adapter": false,
   "apply_spec_augment": true,
   "architectures": [
-    "Wav2Vec2ForPreTraining"
+    "Wav2Vec2SpkRegModel"
   ],
   "attention_dropout": 0.1,
   "auto_map": {
-    "AutoConfig": "configuration_wav2vec2_spkreg.Wav2Vec2SpkRegConfig"
+    "AutoConfig": "configuration_wav2vec2_spkreg.Wav2Vec2SpkRegConfig",
+    "AutoModel": "modeling_wav2vec2_spkreg.Wav2Vec2SpkRegModel"
   },
   "bos_token_id": 1,
   "classifier_proj_size": 256,
@@ -56,7 +58,6 @@
   "feat_quantizer_dropout": 0.0,
   "final_dropout": 0.0,
   "freeze_feat_extract_train": true,
-  "gradient_checkpointing": true,
   "hidden_act": "gelu",
   "hidden_dropout": 0.1,
   "hidden_size": 768,
@@ -119,6 +120,7 @@
     1,
     1
   ],
+  "torch_dtype": "float32",
   "transformers_version": "4.46.2",
   "use_weighted_layer_sum": false,
   "vocab_size": 32,
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:132ac7f4ad2de4d6652f6f6b25354d0f4f22dbd7a8e94d9e03dd4e2518591ca9
+size 377510584
modeling_wav2vec2_spkreg.py ADDED
@@ -0,0 +1,764 @@
import math
import warnings
from typing import Union, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
from transformers.modeling_outputs import SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2ForPreTraining,
    Wav2Vec2GumbelVectorQuantizer,
    Wav2Vec2PositionalConvEmbedding,
    Wav2Vec2FeatureProjection,
    Wav2Vec2AttnAdapterLayer,
    Wav2Vec2ForCTC,
    Wav2Vec2FeatureEncoder,
    Wav2Vec2EncoderStableLayerNorm,
    Wav2Vec2Encoder,
    Wav2Vec2Adapter,
    safe_load_file,
    _compute_mask_indices,
    _HIDDEN_STATES_START_POSITION,
    WAV2VEC2_ADAPTER_SAFE_FILE,
    WAV2VEC2_ADAPTER_PT_FILE
)
from transformers.utils import (
    cached_file,
    is_safetensors_available,
    logging,
)

from .configuration_wav2vec2_spkreg import Wav2Vec2SpkRegConfig

logger = logging.get_logger(__name__)


class Wav2Vec2SpkRegPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Wav2Vec2SpkRegConfig
    base_model_prefix = "wav2vec2"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
        if isinstance(module, Wav2Vec2ForPreTraining):
            module.project_hid.reset_parameters()
            module.project_q.reset_parameters()
            module.project_hid._is_hf_initialized = True
            module.project_q._is_hf_initialized = True
        # gumbel softmax requires special init
        elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, Wav2Vec2PositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, Wav2Vec2FeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # in inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations make sure that all values before the output lengths idxs are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask

    def _get_adapters(self):
        if self.config.adapter_attn_dim is None:
            raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.")

        adapter_weights = {}
        for name, module in self.named_modules():
            if isinstance(module, Wav2Vec2AttnAdapterLayer):
                for param_name, param in module.named_parameters():
                    adapter_weights[".".join([name, param_name])] = param

        if isinstance(self, Wav2Vec2ForCTC):
            for name, param in self.lm_head.named_parameters():
                adapter_weights[".".join(["lm_head", name])] = param

        return adapter_weights

    def init_adapter_layers(self):
        """
        (Re-)initialize attention adapter layers and lm head for adapter-only fine-tuning
        """
        # init attention adapters
        for module in self.modules():
            if isinstance(module, Wav2Vec2AttnAdapterLayer):
                self._init_weights(module)

        # init lm head
        if isinstance(self, Wav2Vec2ForCTC):
            self._init_weights(self.lm_head)

    def load_adapter(self, target_lang: str, force_load=True, **kwargs):
        r"""
        Load a language adapter model from a pre-trained adapter model.

        Parameters:
            target_lang (`str`):
                Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
                adapter.<lang>.safetensors or adapter.<lang>.bin
            force_load (`bool`, defaults to `True`):
                Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

                <Tip>

                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

                </Tip>

            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
        use this method in a firewalled environment.

        </Tip>

        Examples:

        ```python
        >>> from transformers import Wav2Vec2ForCTC, AutoProcessor

        >>> ckpt = "facebook/mms-1b-all"
        >>> processor = AutoProcessor.from_pretrained(ckpt)
        >>> model = Wav2Vec2ForCTC.from_pretrained(ckpt, target_lang="eng")
        >>> # set specific language
        >>> processor.tokenizer.set_target_lang("spa")
        >>> model.load_adapter("spa")
        ```
        """
        if self.config.adapter_attn_dim is None:
            raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")

        if target_lang == self.target_lang and not force_load:
            logger.warning(f"Adapter weights are already set to {target_lang}.")
            return

        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        token = kwargs.pop("token", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        model_path_or_id = self.config._name_or_path
        state_dict = None

        # 1. Let's first try loading a safetensors adapter weight
        if use_safetensors is not False:
            filepath = WAV2VEC2_ADAPTER_SAFE_FILE.format(target_lang)

            try:
                weight_path = cached_file(
                    model_path_or_id,
                    filename=filepath,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    cache_dir=cache_dir,
                )

                state_dict = safe_load_file(weight_path)

            except EnvironmentError:
                if use_safetensors:
                    # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
                    # to the original exception.
                    raise

            except Exception:
                # For any other exception, we throw a generic error.
                if use_safetensors:
                    raise EnvironmentError(
                        f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
                        " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                        f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
                        f" directory containing a file named {filepath}."
                    )

        # 2. If this didn't work let's try loading a PyTorch adapter weight
        if state_dict is None:
            filepath = WAV2VEC2_ADAPTER_PT_FILE.format(target_lang)

            try:
                weight_path = cached_file(
                    model_path_or_id,
                    filename=filepath,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    cache_dir=cache_dir,
                )

                weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
                state_dict = torch.load(
                    weight_path,
                    map_location="cpu",
                    **weights_only_kwarg,
                )

            except EnvironmentError:
                # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
                # to the original exception.
                raise

            except Exception:
                # For any other exception, we throw a generic error.
                raise EnvironmentError(
                    f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
                    f" directory containing a file named {filepath}."
                )

        adapter_weights = self._get_adapters()
        unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
        missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())

        if len(unexpected_keys) > 0:
            raise ValueError(f"The adapter weights {weight_path} has unexpected keys: {', '.join(unexpected_keys)}.")
        elif len(missing_keys) > 0:
            raise ValueError(f"The adapter weights {weight_path} has missing keys: {', '.join(missing_keys)}.")

        # make sure the vocab size is now correct
        target_vocab_size = state_dict["lm_head.weight"].shape[0]
        if target_vocab_size != self.config.vocab_size:
            self.lm_head = nn.Linear(
                self.config.output_hidden_size, target_vocab_size, device=self.device, dtype=self.dtype
            )
            self.config.vocab_size = target_vocab_size

        # make sure that adapter weights are put in exactly the same precision and device placement as the adapter weights they overwrite
        state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
        self.load_state_dict(state_dict, strict=False)

        # set target language correctly
        self.target_lang = target_lang


class Wav2Vec2SpkRegModel(Wav2Vec2SpkRegPreTrainedModel):

    def __init__(self, config: Wav2Vec2SpkRegConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = Wav2Vec2FeatureEncoder(config)
        self.feature_projection = Wav2Vec2FeatureProjection(config)

        # model only needs masking vector if mask prob is > 0.0
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
        else:
            self.encoder = Wav2Vec2Encoder(config)

        self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class AngularLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int):
        super(AngularLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.FloatTensor(out_features, in_features), requires_grad=True
        )
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        # Calculation of cos(theta)
        cosine = F.linear(F.normalize(inputs), F.normalize(self.weight))
        return cosine

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )


class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.
    """
    def __init__(
        self,
        num_labels: int,
        scale: float = 30.0,
        margin: float = 0.35,
    ):
        """
        Args:
            num_labels: Number of classes (output dimension)
            scale: Scaling factor for logits (default: 30.0)
            margin: Additive cosine margin (default: 0.35)
        """
        super(AMSoftmaxLoss, self).__init__()
        self.num_labels = num_labels
        self.scale = scale
        self.margin = margin

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            inputs: Cosine logits of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size)
            label_smoothing: Label smoothing factor (default: 0.0)
            reduction: Reduction method (default: "mean")
        Returns:
            Loss value
        """
        # `inputs` are the outputs from AngularLinear()
        cosine = inputs
        psi = cosine - self.margin
        one_hot = nn.functional.one_hot(targets, self.num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cosine)
        loss = F.cross_entropy(outputs, targets, label_smoothing=label_smoothing, reduction=reduction)
        return loss


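# Note: both margin losses consume the cosine logits produced by AngularLinear. AMSoftmaxLoss
# (above) adds the margin in cosine space, so the target-class logit becomes
# scale * (cos(theta) - margin); AAMSoftmaxLoss (below, ArcFace) adds it to the angle itself,
# scoring the target class as scale * cos(theta + margin), i.e. a margin that is constant in
# angle rather than in cosine.
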
class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax.

    Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
    """
    def __init__(
        self,
        num_labels: int,
        scale: float = 30.0,
        margin: float = 0.35,
        easy_margin: bool = False
    ):
        """
        Args:
            num_labels: Number of classes (output dimension)
            scale: Scaling factor for logits (default: 30.0)
            margin: Angular margin (default: 0.35)
            easy_margin: Use the easy margin loss (default: False)
        """
        super(AAMSoftmaxLoss, self).__init__()
        self.num_labels = num_labels
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            inputs: Cosine logits of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size)
            label_smoothing: Label smoothing factor (default: 0.0)
            reduction: Reduction method (default: "mean")
        Returns:
            Loss value
        """
        # Calculation of cos(theta + m), where `inputs` are the outputs from AngularLinear()
        cosine = inputs
        sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
        phi = cosine * math.cos(self.margin) - sine * math.sin(self.margin)

        # make the function cos(theta + m) monotonically decreasing while theta is in [0°, 180°]
        th = math.cos(math.pi - self.margin)
        mm = math.sin(math.pi - self.margin) * self.margin

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - th) > 0, phi, cosine - mm)

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, targets.view(-1, 1), 1)
        outputs = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        outputs = outputs * self.scale

        loss = F.cross_entropy(outputs, targets, label_smoothing=label_smoothing, reduction=reduction)
        return loss


class Wav2Vec2SpkRegForSequenceClassification(Wav2Vec2SpkRegPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
            )
        self.wav2vec2 = Wav2Vec2SpkRegModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)

        if self.config.loss_fct == 'cross_entropy':
            self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct == 'additive_margin':
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct == 'additive_angular_margin':
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        else:
            raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.loss_fct == 'cross_entropy':
                # nn.CrossEntropyLoss takes label smoothing and reduction at construction time,
                # so they are not passed again when the loss is called.
                loss_fct = nn.CrossEntropyLoss(
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.loss_fct == 'additive_margin':
                loss_fct = AMSoftmaxLoss(
                    self.config.num_labels, self.config.scale, self.config.margin
                )
                loss = loss_fct(
                    logits.view(-1, self.config.num_labels),
                    labels.view(-1),
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_angular_margin':
                loss_fct = AAMSoftmaxLoss(
                    self.config.num_labels, self.config.scale, self.config.margin, self.config.easy_margin
                )
                loss = loss_fct(
                    logits.view(-1, self.config.num_labels),
                    labels.view(-1),
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
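The sequence-classification head selects its classifier and loss from `config.loss_fct`. A hedged training-time sketch follows: the config attributes `loss_fct`, `scale`, `margin`, `easy_margin`, `label_smoothing`, and `reduction` are assumed to be defined by the companion `Wav2Vec2SpkRegConfig`, which this commit references but does not include, and the package-style import path is hypothetical.

```python
# Hedged sketch of how Wav2Vec2SpkRegForSequenceClassification might be trained.
# Assumptions: Wav2Vec2SpkRegConfig defines loss_fct / scale / margin / easy_margin /
# label_smoothing / reduction (it is not part of this commit), and both *_spkreg.py files
# sit in an importable package named wav2vec2_spkreg (hypothetical layout) so that the
# relative import at the top of this module resolves.
import torch
from wav2vec2_spkreg.configuration_wav2vec2_spkreg import Wav2Vec2SpkRegConfig
from wav2vec2_spkreg.modeling_wav2vec2_spkreg import Wav2Vec2SpkRegForSequenceClassification

config = Wav2Vec2SpkRegConfig.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=1251,                      # e.g. a VoxCeleb1-sized speaker set
    loss_fct="additive_angular_margin",   # or "cross_entropy" / "additive_margin"
    scale=30.0,
    margin=0.2,
)
# Random init shown for brevity; in practice the encoder weights would be loaded with
# Wav2Vec2SpkRegForSequenceClassification.from_pretrained("facebook/wav2vec2-base", config=config).
model = Wav2Vec2SpkRegForSequenceClassification(config)
model.freeze_feature_encoder()

input_values = torch.randn(2, 16000)      # two 1-second waveforms sampled at 16 kHz
labels = torch.tensor([3, 17])
outputs = model(input_values, labels=labels)
outputs.loss.backward()
```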