jiangab commited on
Commit
8960e0d
·
verified ·
1 Parent(s): 55e7be7

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/model_pipe.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/rmis_curve.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,86 @@
1
  ---
 
2
  license: mit
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language: en
3
  license: mit
4
+ tags:
5
+ - pytorch
6
  ---
7
+
8
+ <h1 align="center">
9
+ FISHER
10
+ </h1>
11
+
12
+ <div align="center">
13
+ <img src="assets/rmis_curve.png" alt="Model Performances on the RMIS Benchmark" style="width:80%; max-width: 1000px">
14
+ </div>
15
+
16
+
17
+ ## Introduction
18
+
19
+ <div align="center">
20
+ <img src="assets/model_pipe.png" alt="Model Performances on the RMIS Benchmark" style="width:100%; max-width: 1500px">
21
+ </div>
22
+
23
+ FISHER is a **F**oundation model for **I**ndustrial **S**ignal compre**HE**nsive **R**epresentation, which models heterogeneous industrial signals (sound, vibration, voltage, etc.) in a unified manner. FISHER accepts arbitrary sampling rates and models the increment of sampling rate as the concatenation of sub-band information, which first splits a STFT spectrogram into sub-bands before processsing it by the ViT encoder. FISHER is trained by teacher student EMA self-distillation.
24
+
25
+ To evaluate the model, we develop the RMIS benchmark, which will also be open-sourced in the near future. FISHER achieves the SOTA performances on the RMIS benchmark with much more efficient scaling properties.
26
+
27
+ ## Inference
28
+
29
+ Please use the following code to infer the signal representation by FISHER.
30
+
31
+ ```python
32
+ import torch
33
+ import torchaudio
34
+ import torch.nn.functional as F
35
+ from transformers import AutoModel
36
+
37
+ model = AutoModel.from_pretrained('jiangab/FISHER-mini-0723', trust_remote_code=True)
38
+ model = model.cuda()
39
+ model.eval()
40
+
41
+ wav, sr = torchaudio.load('/path/to/local/signal.wav')
42
+ # You can replace it with your custom loading function for other signals
43
+
44
+ wav = wav - wav.mean()
45
+ STFT = torchaudio.transforms.Spectrogram(
46
+ n_fft=25 * sr // 1000,
47
+ win_length=None,
48
+ hop_length=10 * sr // 1000,
49
+ power=1,
50
+ center=False
51
+ )
52
+ spec = torch.log(torch.abs(STFT(wav)) + 1e-10)
53
+ spec = spec.transpose(-2, -1) # [1, time, freq]
54
+ spec = (spec + 3.017344307886898) / (2.1531635155379805 * 2)
55
+
56
+ # time-wise cutoff
57
+ if spec.shape[-2] > 1024:
58
+ spec = spec[:, :1024]
59
+ # freq-wise padding
60
+ if spec.shape[-1] < model.cfg.band_width:
61
+ spec = F.pad(spec, (0, model.cfg.band_width - spec.shape[-1]))
62
+ spec = spec.unsqueeze(1).cuda()
63
+
64
+ with torch.no_grad():
65
+ # Use autocast for mixed precision inference. You can disable it for full precision.
66
+ with torch.autocast('cuda'):
67
+ repre = model.extract_features(spec)
68
+ print(repre.shape)
69
+ ```
70
+
71
+ ## Acknowledgements
72
+
73
+ FISHER is developed based on [EAT](https://github.com/cwx-worst-one/EAT) and [fairseq](https://github.com/facebookresearch/fairseq). We thank these authors for open-sourcing their works.
74
+
75
+ ## Citation
76
+
77
+ If you find FISHER useful, please cite the following paper.
78
+
79
+ ```bibtex
80
+ @article{fan2025fisher,
81
+ title={FISHER: A Foundation Model for Multi-Modal Industrial Signal Comprehensive Representation},
82
+ author={Fan, Pingyi and Jiang, Anbai and Zhang, Shuwei and Lv, Zhiqiang and Han, Bing and Zheng, Xinhu and Liang, Wenrui and Li, Junjie and Zhang, Wei-Qiang and Qian, Yanmin and Chen, Xie and Lu, Cheng and Liu, Jia},
83
+ journal={arXiv preprint arXiv:2507.16696},
84
+ year={2025}
85
+ }
86
+ ```
assets/icon.jpg ADDED
assets/model_pipe.png ADDED

Git LFS Details

  • SHA256: a1c3a1f1f762135e62b6d97826553a3baa435308e20f3c58940eb1c164f9e355
  • Pointer size: 132 Bytes
  • Size of remote file: 3.43 MB
assets/rmis_curve.png ADDED

Git LFS Details

  • SHA256: 21e78c74d71589376606b149ac27e1a7b2b262197abd6fa20ebcb8555fffb765
  • Pointer size: 131 Bytes
  • Size of remote file: 403 kB
base.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from collections import namedtuple
8
+ from dataclasses import dataclass
9
+ from functools import partial
10
+ from omegaconf import MISSING, II
11
+ from typing import Optional, Callable
12
+ from enum import Enum, auto
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class Modality(Enum):
19
+ AUDIO = auto()
20
+ IMAGE = auto()
21
+ TEXT = auto()
22
+
23
+
24
+ @dataclass
25
+ class D2vModalityConfig:
26
+ type: Modality = MISSING
27
+ prenet_depth: int = 0
28
+ prenet_layerdrop: float = 0.0
29
+ prenet_dropout: float = 0.0
30
+ start_drop_path_rate: float = 0.0
31
+ end_drop_path_rate: float = 0.0
32
+
33
+ num_extra_tokens: int = 1
34
+ init_extra_token_zero: bool = False
35
+
36
+ mask_noise_std: float = 0.01
37
+ mask_prob_min: Optional[float] = None
38
+ mask_prob: float = 0.8
39
+ inverse_mask: bool = True
40
+ mask_prob_adjust: float = 0.07
41
+ keep_masked_pct: float = 0.0
42
+ flexible_mask: bool = False
43
+
44
+ mask_length: int = 5
45
+ add_masks: bool = False
46
+ remove_masks: bool = False
47
+ mask_dropout: float = 0.0
48
+ encoder_zero_mask: bool = True
49
+
50
+ mask_channel_prob: float = 0.0
51
+ mask_channel_length: int = 64
52
+
53
+ ema_local_encoder: bool = True # used in data2vec_multi
54
+ ema_local_decoder: bool = False
55
+ local_grad_mult: float = 1.0
56
+ flatten: str = 'freq'
57
+ max_length: int = 128
58
+ max_freq: int = 50
59
+
60
+ use_alibi_encoder: bool = False
61
+ alibi_scale: float = 1.0
62
+ learned_alibi: bool = False
63
+ alibi_max_pos: Optional[int] = None
64
+ learned_alibi_scale: bool = False
65
+ learned_alibi_scale_per_head: bool = False
66
+ learned_alibi_scale_per_layer: bool = False
67
+
68
+ num_alibi_heads: int = II("model.num_heads")
69
+ model_depth: int = II("model.depth")
70
+
71
+
72
+ MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
73
+
74
+
75
+ class ModalitySpecificEncoder(nn.Module):
76
+ def __init__(
77
+ self,
78
+ modality_cfg: D2vModalityConfig,
79
+ embed_dim: int,
80
+ local_encoder: nn.Module,
81
+ project_features: nn.Module,
82
+ fixed_positional_encoder: Optional[nn.Module],
83
+ relative_positional_encoder: Optional[nn.Module], # None
84
+ context_encoder: nn.Module,
85
+ decoder: Optional[nn.Module],
86
+ get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
87
+ ):
88
+ super().__init__()
89
+
90
+ self.modality_cfg = modality_cfg
91
+ self.local_encoder = local_encoder
92
+ self.project_features = project_features
93
+ self.fixed_positional_encoder = fixed_positional_encoder
94
+ self.relative_positional_encoder = relative_positional_encoder
95
+ self.context_encoder = context_encoder
96
+
97
+ self.decoder = decoder
98
+ self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None
99
+
100
+ self.local_grad_mult = self.modality_cfg.local_grad_mult
101
+
102
+ self.extra_tokens = None
103
+ if modality_cfg.num_extra_tokens > 0:
104
+ self.extra_tokens = nn.Parameter(
105
+ torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
106
+ )
107
+ if not modality_cfg.init_extra_token_zero:
108
+ nn.init.normal_(self.extra_tokens)
109
+ elif self.extra_tokens.size(1) > 1:
110
+ nn.init.normal_(self.extra_tokens[:, 1:])
111
+
112
+ self.alibi_scale = None
113
+ if self.get_alibi_bias is not None:
114
+ self.alibi_scale = nn.Parameter(
115
+ torch.full(
116
+ (
117
+ (modality_cfg.prenet_depth + modality_cfg.model_depth)
118
+ if modality_cfg.learned_alibi_scale_per_layer
119
+ else 1,
120
+ 1,
121
+ self.modality_cfg.num_alibi_heads
122
+ if modality_cfg.learned_alibi_scale_per_head
123
+ else 1,
124
+ 1,
125
+ 1,
126
+ ),
127
+ modality_cfg.alibi_scale,
128
+ dtype=torch.float,
129
+ ),
130
+ requires_grad=modality_cfg.learned_alibi_scale,
131
+ )
132
+
133
+ if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
134
+ assert modality_cfg.alibi_max_pos is not None
135
+ alibi_bias = self.get_alibi_bias(
136
+ batch_size=1,
137
+ time_steps=modality_cfg.alibi_max_pos,
138
+ heads=modality_cfg.num_alibi_heads,
139
+ scale=1.0,
140
+ dtype=torch.float,
141
+ device="cpu",
142
+ )
143
+ self.alibi_bias = nn.Parameter(alibi_bias)
144
+ self.get_alibi_bias = partial(
145
+ _learned_alibi_bias, alibi_bias=self.alibi_bias
146
+ )
147
+
148
+ def upgrade_state_dict_named(self, state_dict, name):
149
+ k = f"{name}.alibi_scale"
150
+ if k in state_dict and state_dict[k].dim() == 4:
151
+ state_dict[k] = state_dict[k].unsqueeze(0)
152
+
153
+ return state_dict
154
+
155
+ def convert_padding_mask(self, x, padding_mask):
156
+ return padding_mask
157
+
158
+ def local_features(self, features):
159
+ x = self.local_encoder(features)
160
+ x = self.project_features(x) # nn.Identity()
161
+ return x
162
+
163
+ def contextualized_features(
164
+ self,
165
+ x,
166
+ padding_mask,
167
+ mask, # True
168
+ remove_masked, # train: True; infer: False
169
+ clone_batch: int = 1,
170
+ mask_seeds: Optional[torch.Tensor] = None,
171
+ precomputed_mask=None,
172
+ ):
173
+
174
+ if padding_mask is not None:
175
+ padding_mask = self.convert_padding_mask(x, padding_mask) # [b,t,f] => [b,seq]
176
+
177
+ local_features = x
178
+ if mask and clone_batch == 1:
179
+ local_features = local_features.clone()
180
+
181
+ orig_B, orig_T, _ = x.shape
182
+ pre_mask_B = orig_B
183
+ mask_info = None
184
+
185
+ x_pos = None
186
+ # x: [B, seq_len, embed_dim]
187
+ if self.fixed_positional_encoder is not None: # models.modules.FixPositionalEncoder
188
+ x = x + self.fixed_positional_encoder(x, padding_mask)[:, :x.size(1), :]
189
+
190
+ if self.relative_positional_encoder is not None:
191
+ x_pos = self.relative_positional_encoder(x)
192
+
193
+ masked_padding_mask = padding_mask
194
+
195
+ alibi_bias = None
196
+ alibi_scale = self.alibi_scale
197
+
198
+ if self.get_alibi_bias is not None:
199
+ alibi_bias = self.get_alibi_bias(
200
+ batch_size=pre_mask_B,
201
+ time_steps=orig_T,
202
+ heads=self.modality_cfg.num_alibi_heads,
203
+ dtype=torch.float32,
204
+ device=x.device,
205
+ )
206
+
207
+ if alibi_scale is not None:
208
+ alibi_scale = alibi_scale.clamp_min(0)
209
+ if alibi_scale.size(0) == 1:
210
+ alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
211
+ alibi_scale = None
212
+
213
+ if clone_batch > 1:
214
+ alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)
215
+
216
+ if mask_info is not None and remove_masked:
217
+ alibi_bias = masked_alibi(alibi_bias, mask_info)
218
+
219
+ if self.extra_tokens is not None:
220
+ num = self.extra_tokens.size(1)
221
+ x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
222
+ if masked_padding_mask is not None:
223
+ # B x T
224
+ masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
225
+ if alibi_bias is not None:
226
+ # B x H x T x T
227
+ alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))
228
+
229
+ x = self.context_encoder(
230
+ x,
231
+ masked_padding_mask,
232
+ alibi_bias,
233
+ alibi_scale[: self.modality_cfg.prenet_depth]
234
+ if alibi_scale is not None
235
+ else None,
236
+ )
237
+
238
+ return {
239
+ "x": x,
240
+ "local_features": local_features,
241
+ "padding_mask": masked_padding_mask,
242
+ "alibi_bias": alibi_bias,
243
+ "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
244
+ if alibi_scale is not None and alibi_scale.size(0) > 1
245
+ else alibi_scale,
246
+ "encoder_mask": mask_info,
247
+ }
248
+
249
+ def forward(
250
+ self,
251
+ features,
252
+ padding_mask,
253
+ mask: bool,
254
+ remove_masked: bool,
255
+ clone_batch: int = 1,
256
+ mask_seeds: Optional[torch.Tensor] = None,
257
+ precomputed_mask=None,
258
+ ):
259
+ x = self.local_features(features) # patch embed
260
+ # x: [bs, time*freq, embed_dim], e.g. [12, 512, 768]
261
+ out = self.contextualized_features(
262
+ x,
263
+ padding_mask,
264
+ mask,
265
+ remove_masked,
266
+ clone_batch,
267
+ mask_seeds,
268
+ precomputed_mask,
269
+ ) # add mask, discarded masked, context encoder (only layer norm)
270
+ return out
271
+
272
+ def reset_parameters(self):
273
+ pass
274
+
275
+ def remove_pretraining_modules(self, keep_decoder=False):
276
+ if not keep_decoder:
277
+ self.decoder = None
278
+
279
+
280
+ def get_annealed_rate(start, end, curr_step, total_steps):
281
+ if curr_step >= total_steps:
282
+ return end
283
+ r = end - start
284
+ pct_remaining = 1 - curr_step / total_steps
285
+ return end - r * pct_remaining
286
+
287
+
288
+
289
+ def get_alibi(
290
+ max_positions: int,
291
+ attention_heads: int,
292
+ dims: int = 1,
293
+ distance: str = "manhattan",
294
+ ):
295
+ def get_slopes(n):
296
+ def get_slopes_power_of_2(n):
297
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
298
+ ratio = start
299
+ return [start * ratio**i for i in range(n)]
300
+
301
+ # In the paper, we only train models that have 2^a heads for some
302
+ # a. This function has some good properties that only occur when
303
+ # the input is a power of 2. To maintain that even when the number
304
+ # of heads is not a power of 2, we use this workaround.
305
+ if math.log2(n).is_integer():
306
+ return get_slopes_power_of_2(n)
307
+ else:
308
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
309
+ return (
310
+ get_slopes_power_of_2(closest_power_of_2)
311
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
312
+ )
313
+
314
+ maxpos = max_positions
315
+ attn_heads = attention_heads
316
+ slopes = torch.Tensor(get_slopes(attn_heads))
317
+
318
+ if dims == 1:
319
+ # prepare alibi position linear bias. Note that wav2vec2 is non
320
+ # autoregressive model so we want a symmetric mask with 0 on the
321
+ # diagonal and other wise linear decreasing valuees
322
+ pos_bias = (
323
+ torch.abs(
324
+ torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
325
+ )
326
+ * -1
327
+ )
328
+ elif dims == 2:
329
+ if distance == "manhattan":
330
+ df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
331
+ elif distance == "euclidean":
332
+ df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
333
+
334
+ n = math.sqrt(max_positions)
335
+ assert n.is_integer(), n
336
+ n = int(n)
337
+
338
+ pos_bias = torch.zeros((max_positions, max_positions))
339
+
340
+ for i in range(n):
341
+ for j in range(n):
342
+ for k in range(n):
343
+ for l in range(n):
344
+ new_x = i * n + j
345
+ new_y = k * n + l
346
+ pos_bias[new_x, new_y] = -df(i, j, k, l)
347
+
348
+ else:
349
+ raise Exception(f"unsupported number of alibi dims: {dims}")
350
+
351
+ alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
352
+ attn_heads, -1, -1
353
+ )
354
+
355
+ return alibi_bias
356
+
357
+
358
+ def get_alibi_bias(
359
+ alibi_biases,
360
+ batch_size,
361
+ time_steps,
362
+ heads,
363
+ dtype,
364
+ device,
365
+ dims=1,
366
+ distance="manhattan",
367
+ ):
368
+ cache_key = f"{dims}_{heads}_{distance}"
369
+
370
+ buffered = alibi_biases.get(cache_key, None)
371
+
372
+ target_size = heads * batch_size
373
+ if (
374
+ buffered is None
375
+ or buffered.size(0) < target_size
376
+ or buffered.size(1) < time_steps
377
+ or buffered.dtype != dtype
378
+ or buffered.device != device
379
+ ):
380
+ bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
381
+ bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
382
+
383
+ buffered = (
384
+ get_alibi(bt, heads, dims=dims, distance=distance)
385
+ .to(dtype=dtype, device=device)
386
+ .repeat(bn, 1, 1)
387
+ )
388
+
389
+ alibi_biases[cache_key] = buffered
390
+
391
+ b = buffered[:target_size, :time_steps, :time_steps]
392
+ b = b.view(batch_size, heads, time_steps, time_steps)
393
+ return b
394
+
395
+
396
+ def _learned_alibi_bias(
397
+ alibi_bias,
398
+ batch_size,
399
+ time_steps,
400
+ heads,
401
+ scale,
402
+ dtype,
403
+ device,
404
+ ):
405
+ assert alibi_bias.size(1) == heads, alibi_bias.shape
406
+ assert alibi_bias.dtype == dtype, alibi_bias.dtype
407
+ assert alibi_bias.device == device, alibi_bias.device
408
+
409
+ if alibi_bias.size(-1) < time_steps:
410
+ psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
411
+ alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
412
+
413
+ alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
414
+ return alibi_bias[..., :time_steps, :time_steps]
415
+
416
+
417
+ def masked_alibi(alibi_bias, mask_info):
418
+ H = alibi_bias.size(1)
419
+
420
+ orig_bias = alibi_bias
421
+
422
+ index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
423
+ alibi_bias = torch.gather(
424
+ orig_bias,
425
+ dim=-2,
426
+ index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
427
+ )
428
+ alibi_bias = torch.gather(
429
+ alibi_bias,
430
+ dim=-1,
431
+ index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
432
+ )
433
+
434
+ return alibi_bias
config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "FISHERModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_fisher.FISHERConfig",
7
+ "AutoModel": "modeling_fisher.FISHERModel"
8
+ },
9
+ "band_width": 100,
10
+ "depth": 12,
11
+ "embed_dim": 256,
12
+ "max_band_per_sample": 64,
13
+ "model_type": "fisher",
14
+ "num_heads": 4,
15
+ "torch_dtype": "float32",
16
+ "transformers_version": "4.53.3"
17
+ }
configuration_fisher.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class FISHERConfig(PretrainedConfig):
5
+ model_type = "fisher"
6
+
7
+ def __init__(
8
+ self,
9
+ band_width=100,
10
+ embed_dim=192,
11
+ num_heads=3,
12
+ max_band_per_sample=64,
13
+ depth=12,
14
+ **kwargs,
15
+ ):
16
+ super().__init__(**kwargs)
17
+
18
+ self.band_width = band_width
19
+ self.embed_dim = embed_dim
20
+ self.depth = depth
21
+ self.num_heads = num_heads
22
+ self.max_band_per_sample = max_band_per_sample
images.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ from functools import partial
6
+ from dataclasses import dataclass
7
+ from typing import Callable, Dict, Optional
8
+ from enum import Enum, auto
9
+ from einops import rearrange
10
+ from omegaconf import II
11
+
12
+ from .modules import get_2d_sincos_pos_embed_flexible, PatchEmbed_new
13
+
14
+
15
+ from .base import (
16
+ D2vModalityConfig,
17
+ ModalitySpecificEncoder,
18
+ get_alibi_bias,
19
+ )
20
+ from .modules import (
21
+ BlockEncoder,
22
+ FixedPositionalEncoder,
23
+ )
24
+
25
+
26
+ class Modality(Enum):
27
+ AUDIO = auto()
28
+ IMAGE = auto()
29
+ TEXT = auto()
30
+
31
+
32
+ @dataclass
33
+ class D2vImageConfig(D2vModalityConfig):
34
+ type: Modality = Modality.IMAGE
35
+
36
+ in_chans: int = 1
37
+ patch_size: int = 16
38
+ embed_dim: int = II('model.embed_dim')
39
+
40
+ alibi_dims: int = 2
41
+ alibi_distance: str = "manhattan"
42
+
43
+ fixed_positions: bool = True
44
+
45
+ transformer_decoder: bool = False
46
+ enc_dec_transformer: bool = False
47
+ target_length: int = 1024
48
+ max_length: int = 128
49
+ max_freq: int = 50
50
+
51
+ flatten: str = 'freq' # 'time', 'freq'
52
+
53
+
54
+ class ImageEncoder(ModalitySpecificEncoder):
55
+ # forward() implemented in models.base.ModalitySpecificEncoder
56
+
57
+ modality_cfg: D2vImageConfig
58
+
59
+ def __init__(
60
+ self,
61
+ modality_cfg: D2vImageConfig,
62
+ embed_dim: int,
63
+ make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
64
+ norm_layer: Callable[[int], nn.LayerNorm],
65
+ layer_norm_first: bool,
66
+ alibi_biases: Dict,
67
+ task=None,
68
+ ):
69
+ self.patch_size = modality_cfg.patch_size
70
+ self.H = modality_cfg.target_length // self.patch_size # 64
71
+
72
+ # convert spec to patch embed, using conv1d
73
+ local_encoder = PatchEmbed_new(
74
+ patch_size=modality_cfg.patch_size, # 16
75
+ in_chans=modality_cfg.in_chans, # 1
76
+ embed_dim=modality_cfg.embed_dim, # 768
77
+ stride=modality_cfg.patch_size, # 16
78
+ flatten=modality_cfg.flatten
79
+ )
80
+
81
+ # CNN initialize
82
+ w = local_encoder.proj.weight.data
83
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
84
+
85
+ if modality_cfg.embed_dim != embed_dim:
86
+ local_encoder = nn.Sequential(
87
+ local_encoder,
88
+ nn.Linear(modality_cfg.embed_dim, embed_dim),
89
+ )
90
+
91
+ project_features = nn.Identity()
92
+
93
+ # note: max_length control the maximum time length of audio -> "64" for 10s, here we define it as 2min, you can change it yourself
94
+ max_length = modality_cfg.max_length
95
+ max_freq = modality_cfg.max_freq
96
+
97
+ # side_n = int(num_patches ** 0.5)
98
+ # note: we fix the variable length sequence problem here -> support up to 2min audio
99
+ emb = get_2d_sincos_pos_embed_flexible(
100
+ embed_dim,
101
+ (max_length, max_freq),
102
+ cls_token=False,
103
+ )
104
+ pos_embed = torch.from_numpy(emb[:max_length * max_freq, :]).float().unsqueeze(0)
105
+
106
+ fixed_positional_encoder = (
107
+ FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None # True
108
+ )
109
+
110
+ dpr = np.linspace( # drop_path_rate
111
+ modality_cfg.start_drop_path_rate,
112
+ modality_cfg.end_drop_path_rate,
113
+ modality_cfg.prenet_depth, # actual: 0
114
+ )
115
+
116
+ # actual: only layer norm
117
+ context_encoder = BlockEncoder(
118
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
119
+ norm_layer(embed_dim) if not layer_norm_first else None,
120
+ layer_norm_first,
121
+ modality_cfg.prenet_layerdrop,
122
+ modality_cfg.prenet_dropout,
123
+ )
124
+
125
+ alibi_bias_fn = partial(
126
+ get_alibi_bias,
127
+ alibi_biases=alibi_biases,
128
+ heads=modality_cfg.num_alibi_heads,
129
+ dims=modality_cfg.alibi_dims,
130
+ distance=modality_cfg.alibi_distance,
131
+ )
132
+
133
+ super().__init__(
134
+ modality_cfg=modality_cfg,
135
+ embed_dim=embed_dim,
136
+ local_encoder=local_encoder, # patch embed
137
+ project_features=project_features, # nn.Identity()
138
+ fixed_positional_encoder=fixed_positional_encoder,
139
+ relative_positional_encoder=None,
140
+ context_encoder=context_encoder, # apply mask
141
+ decoder=None,
142
+ get_alibi_bias=alibi_bias_fn,
143
+ )
144
+
145
+ def reset_parameters(self):
146
+ super().reset_parameters()
147
+
148
+ @torch.no_grad()
149
+ def patchify(self, imgs):
150
+ """
151
+ imgs: (N, 3, H, W) audio: (N,1,H,W) 1024/16 = 64 128/16 = 8
152
+ x: (N, L, patch_size**2 *3)
153
+ """
154
+ if self.modality_cfg.in_chans == 1: # actual: this one
155
+ p = self.modality_cfg.patch_size
156
+ h = imgs.shape[2] // p
157
+ w = imgs.shape[3] // p
158
+ # h,w = self.patch_embed.patch_hw
159
+ x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
160
+ x = torch.einsum('nchpwq->nhwpqc', x)
161
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
162
+
163
+ else:
164
+ p = self.modality_cfg.patch_size
165
+ h = w = imgs.shape[2] // p
166
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
167
+ x = torch.einsum("nchpwq->nhwpqc", x)
168
+ x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
169
+
170
+ return x
171
+
172
+ @torch.no_grad()
173
+ def unpatchify(self, x):
174
+ """
175
+ x: (N, L, patch_size**2 *C)
176
+ imgs: (N, C, H, W)
177
+ """
178
+ p = self.modality_cfg.patch_size
179
+ h = w = int(x.shape[1] ** 0.5) # num patch along two axis
180
+ assert h * w == x.shape[1]
181
+
182
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, -1))
183
+ x = torch.einsum("nhwpqc->nchpwq", x)
184
+ imgs = x.reshape(shape=(x.shape[0], -1, h * p, h * p))
185
+ return imgs
186
+
187
+ def convert_padding_mask(
188
+ self,
189
+ x: torch.Tensor,
190
+ padding_mask: torch.Tensor
191
+ ) -> torch.Tensor:
192
+ '''patchify and serialize padding_mask: [b,t,f] => [b,t_patch,f_patch] => [b,patch_seq]
193
+
194
+ Args:
195
+ x (torch.Tensor): input_features
196
+ padding_mask (torch.Tensor): [b,t_patch,f_patch], 1 for padded patch
197
+
198
+ Returns:
199
+ torch.Tensor: serialized padding mask. [b,patch_seq]
200
+ '''
201
+ B, T, F = x.shape
202
+ t_extra, f_extra = T % self.patch_size, F % self.patch_size
203
+ padding_mask = padding_mask[:, :-t_extra, :-f_extra]
204
+ padding_mask = rearrange(
205
+ padding_mask,
206
+ 'b (tp p) (fp q) -> b tp fp (p q)',
207
+ p=self.patch_size, q=self.patch_size
208
+ )
209
+ padding_mask = padding_mask.all(-1)
210
+
211
+ if self.modality_cfg.flatten == 'time':
212
+ padding_mask = padding_mask.transpose(-2, -1).flatten(1)
213
+ else:
214
+ padding_mask = padding_mask.flatten(1)
215
+ return padding_mask
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8a96eda2821e72e1bf4e06d9171f7623b2d9e772de6596e6129a43955158cc5
3
+ size 38189336
modeling_fisher.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+
5
+ from functools import partial
6
+ from einops import rearrange
7
+ from typing import Callable, Optional
8
+ from dataclasses import dataclass, field, is_dataclass
9
+ from transformers import PreTrainedModel
10
+
11
+ from .configuration_fisher import FISHERConfig
12
+ from .base import (
13
+ D2vModalityConfig,
14
+ ModalitySpecificEncoder,
15
+ )
16
+ from .modules import AltBlock
17
+ from .images import (
18
+ D2vImageConfig,
19
+ ImageEncoder,
20
+ )
21
+
22
+
23
+ @dataclass
24
+ class D2vModalitiesConfig:
25
+ image: D2vImageConfig = field(default_factory=lambda *args: D2vImageConfig())
26
+
27
+
28
+ @dataclass
29
+ class Data2VecMultiConfig:
30
+ depth: int = 12
31
+
32
+ # band split
33
+ band_width: int = 100
34
+
35
+ # standard vision Transformer
36
+ start_drop_path_rate: float = 0.0
37
+ end_drop_path_rate: float = 0.0
38
+ num_heads: int = 12
39
+ norm_eps: float = 1e-6
40
+ norm_affine: bool = True
41
+ encoder_dropout: float = 0.0
42
+ post_mlp_drop: float = 0.0
43
+ attention_dropout: float = 0.0
44
+ activation_dropout: float = 0.0
45
+ dropout_input: float = 0.0
46
+ layerdrop: float = 0.0
47
+ embed_dim: int = 768
48
+ mlp_ratio: float = 4.0
49
+ layer_norm_first: bool = False
50
+
51
+ end_of_block_targets: bool = False
52
+
53
+ # clone batch for multi-mask strategy
54
+ max_band_per_sample: int = 64
55
+
56
+ # normalization for teacher Transformer layer output
57
+ layer_norm_target_layer: bool = False
58
+ batch_norm_target_layer: bool = False
59
+ instance_norm_target_layer: bool = True
60
+ instance_norm_targets: bool = False
61
+ layer_norm_targets: bool = True
62
+
63
+ modalities: D2vModalitiesConfig = field(default_factory=lambda *args: D2vModalitiesConfig())
64
+
65
+
66
+ def update_dataclass(instance, data_dict):
67
+ if not data_dict:
68
+ return instance
69
+
70
+ for field_name, field_value in data_dict.items():
71
+ if hasattr(instance, field_name):
72
+ current_value = getattr(instance, field_name)
73
+ if is_dataclass(current_value) and isinstance(field_value, dict):
74
+ update_dataclass(current_value, field_value)
75
+ else:
76
+ setattr(instance, field_name, field_value)
77
+ return instance
78
+
79
+
80
+ class FISHER(nn.Module):
81
+ def __init__(self, config: FISHERConfig):
82
+ super().__init__()
83
+ cfg = Data2VecMultiConfig()
84
+ update_dataclass(cfg, config.to_dict())
85
+ cfg.modalities.image.embed_dim = cfg.embed_dim
86
+ cfg.modalities.image.embed_dim = cfg.embed_dim
87
+ self.cfg = cfg
88
+
89
+ make_layer_norm = partial(
90
+ nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
91
+ )
92
+
93
+ def make_block(drop_path, dim=None, heads=None):
94
+ return AltBlock(
95
+ cfg.embed_dim if dim is None else dim,
96
+ cfg.num_heads if heads is None else heads,
97
+ cfg.mlp_ratio,
98
+ qkv_bias=True,
99
+ drop=0.0,
100
+ attn_drop=cfg.attention_dropout,
101
+ mlp_drop=cfg.activation_dropout,
102
+ post_mlp_drop=cfg.post_mlp_drop,
103
+ drop_path=drop_path,
104
+ norm_layer=make_layer_norm,
105
+ layer_norm_first=cfg.layer_norm_first,
106
+ ffn_targets=not cfg.end_of_block_targets,
107
+ )
108
+
109
+ self.alibi_biases = {}
110
+ self.modality_encoders = nn.ModuleDict()
111
+
112
+ mod_cfg = getattr(cfg.modalities, 'image')
113
+ enc = self.make_modality_encoder(
114
+ mod_cfg,
115
+ cfg.embed_dim,
116
+ make_block,
117
+ make_layer_norm,
118
+ cfg.layer_norm_first,
119
+ self.alibi_biases,
120
+ )
121
+ self.modality_encoders['IMAGE'] = enc
122
+
123
+ dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
124
+
125
+ self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
126
+
127
+ self.norm = None
128
+ if cfg.layer_norm_first:
129
+ self.norm = make_layer_norm(cfg.embed_dim)
130
+
131
+ # band split
132
+ self.band_width = cfg.band_width
133
+ self.patch_size = cfg.modalities.image.patch_size
134
+
135
+ def make_modality_encoder(
136
+ self,
137
+ cfg: D2vModalityConfig,
138
+ embed_dim: int,
139
+ make_block: Callable[[float], nn.ModuleList],
140
+ norm_layer: Callable[[int], nn.LayerNorm],
141
+ layer_norm_first: bool,
142
+ alibi_biases,
143
+ task=None,
144
+ ) -> ModalitySpecificEncoder:
145
+ return ImageEncoder(
146
+ cfg,
147
+ embed_dim,
148
+ make_block,
149
+ norm_layer,
150
+ layer_norm_first,
151
+ alibi_biases,
152
+ task,
153
+ )
154
+
155
+ def forward(
156
+ self,
157
+ source: torch.Tensor,
158
+ target=None,
159
+ id=None,
160
+ mode='IMAGE',
161
+ padding_mask: Optional[torch.Tensor] = None,
162
+ mask: bool = True,
163
+ features_only: bool = False,
164
+ force_remove_masked=False,
165
+ precomputed_mask: Optional[torch.Tensor] = None,
166
+ ):
167
+ # band split
168
+ num_band = source.shape[-1] // self.band_width
169
+ source = torch.stack(source.split(self.band_width, dim=-1)[:num_band]) # drop residual
170
+ source = rearrange(source, 'nb B c t f -> (B nb) c t f')
171
+ clone_batch = self.cfg.max_band_per_sample // num_band
172
+
173
+ feature_extractor = self.modality_encoders[mode] # models.images.ImageEncoder
174
+
175
+ # extract (unmasked) features using CNN encoder
176
+ extractor_out = feature_extractor(
177
+ source,
178
+ padding_mask,
179
+ mask,
180
+ remove_masked=not features_only or force_remove_masked, # train: True; infer: False
181
+ clone_batch=clone_batch if not features_only else 1,
182
+ mask_seeds=None,
183
+ precomputed_mask=precomputed_mask,
184
+ )
185
+
186
+ # x in shape (batch_size * clone batch, patch_frame(64) * patch_freqency(8) * unmask_ratio(0.2) + 1(cls_token), 768(feature dimension))
187
+ x = extractor_out["x"]
188
+ # encoder_mask is applied on sub-band level
189
+ encoder_mask = extractor_out["encoder_mask"] # models.base.MaskInfo, ["x_unmasked", "mask", "ids_restore", "ids_keep"]
190
+ masked_padding_mask = extractor_out["padding_mask"]
191
+ masked_alibi_bias = extractor_out.get("alibi_bias", None)
192
+ alibi_scale = extractor_out.get("alibi_scale", None)
193
+
194
+ # standard Transformer (for student encoder)
195
+ layer_results = []
196
+ for i, blk in enumerate(self.blocks):
197
+ ab = masked_alibi_bias
198
+ if ab is not None and alibi_scale is not None:
199
+ scale = (
200
+ alibi_scale[i]
201
+ if alibi_scale.size(0) > 1
202
+ else alibi_scale.squeeze(0)
203
+ )
204
+ ab = ab * scale.type_as(ab)
205
+
206
+ x, lr = blk(
207
+ x,
208
+ padding_mask=masked_padding_mask,
209
+ alibi_bias=ab,
210
+ )
211
+ if features_only:
212
+ layer_results.append(lr)
213
+
214
+ if self.norm is not None:
215
+ x = self.norm(x)
216
+
217
+ # extract features for fine-tuning
218
+ if features_only:
219
+ return {
220
+ "x": x,
221
+ "padding_mask": masked_padding_mask,
222
+ "layer_results": layer_results,
223
+ "mask": encoder_mask,
224
+ }
225
+
226
+ def extract_features(
227
+ self, source, mode='IMAGE', padding_mask=None, mask=False
228
+ ):
229
+ num_band = source.shape[-1] // self.band_width
230
+ res = self.forward(
231
+ source,
232
+ mode=mode,
233
+ padding_mask=padding_mask,
234
+ mask=mask,
235
+ features_only=True,
236
+ )
237
+ x = res['x'][:, 0]
238
+ x = rearrange(x, '(B nb) D -> B (nb D)', nb=num_band)
239
+ return x
240
+
241
+
242
+ class FISHERModel(PreTrainedModel):
243
+ config_class = FISHERConfig
244
+
245
+ def __init__(self, cfg: FISHERConfig):
246
+ super().__init__(cfg)
247
+ self.cfg = cfg
248
+ self.model = FISHER(cfg)
249
+
250
+ def forward(self, *args, **kwargs):
251
+ return self.model(*args, **kwargs)
252
+
253
+ def extract_features(self, x):
254
+ return self.model.extract_features(x)
modules.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from timm.models.layers import to_2tuple
6
+
7
+
8
+ class PatchEmbed_new(nn.Module):
9
+ """ Flexible Image to Patch Embedding
10
+ """
11
+ def __init__(
12
+ self,
13
+ patch_size=16,
14
+ in_chans=3,
15
+ embed_dim=768,
16
+ stride=16,
17
+ flatten='freq'
18
+ ):
19
+ super().__init__()
20
+ self.flatten = flatten
21
+ patch_size = to_2tuple(patch_size)
22
+ stride = to_2tuple(stride)
23
+ assert flatten in ['time', 'freq']
24
+
25
+ self.patch_size = patch_size
26
+
27
+ # no padding for conv
28
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches
29
+
30
+ def forward(self, x):
31
+ x = self.proj(x) # (B,768,64,8)
32
+ if self.flatten == 'freq':
33
+ x = x.flatten(2).transpose(1, 2) # flatten from dim
34
+ else:
35
+ x = x.transpose(-2, -1).flatten(2).transpose(1, 2)
36
+ return x
37
+
38
+
39
+ def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
40
+ """
41
+ grid_size: int of the grid height and width
42
+ return:
43
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
44
+ """
45
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
46
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
47
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
48
+ grid = np.stack(grid, axis=0)
49
+
50
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
51
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
52
+ if cls_token:
53
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
54
+ return pos_embed
55
+
56
+
57
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
58
+ assert embed_dim % 2 == 0
59
+
60
+ # use half of dimensions to encode grid_h
61
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
62
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
63
+
64
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
65
+ return emb
66
+
67
+
68
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
69
+ """
70
+ embed_dim: output dimension for each position
71
+ pos: a list of positions to be encoded: size (M,)
72
+ out: (M, D)
73
+ """
74
+ assert embed_dim % 2 == 0
75
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
76
+ omega /= embed_dim / 2.0
77
+ omega = 1.0 / 10000 ** omega # (D/2,)
78
+
79
+ pos = pos.reshape(-1) # (M,)
80
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
81
+
82
+ emb_sin = np.sin(out) # (M, D/2)
83
+ emb_cos = np.cos(out) # (M, D/2)
84
+
85
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
86
+ return emb
87
+
88
+
89
+ class FixedPositionalEncoder(nn.Module):
90
+ def __init__(self, pos_embed: torch.Tensor):
91
+ super().__init__()
92
+ self.positions = pos_embed
93
+
94
+ def forward(self, x: torch.Tensor, padding_mask):
95
+ return self.positions.to(x.device)
96
+
97
+
98
+ class BlockEncoder(nn.Module):
99
+ def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
100
+ super().__init__()
101
+ self.blocks = blocks
102
+ self.norm = norm_layer
103
+ self.layer_norm_first = layer_norm_first
104
+ self.layerdrop = layerdrop
105
+ self.dropout = nn.Dropout(dropout, inplace=True)
106
+
107
+ def forward(self, x, padding_mask, alibi_bias, alibi_scale):
108
+ if self.norm is not None and not self.layer_norm_first:
109
+ x = self.norm(x)
110
+
111
+ x = self.dropout(x)
112
+
113
+ for i, blk in enumerate(self.blocks):
114
+ if (
115
+ not self.training
116
+ or self.layerdrop == 0
117
+ or (np.random.random() > self.layerdrop)
118
+ ):
119
+ ab = alibi_bias
120
+ if ab is not None and alibi_scale is not None:
121
+ scale = (
122
+ alibi_scale[i]
123
+ if alibi_scale.size(0) > 1
124
+ else alibi_scale.squeeze(0)
125
+ )
126
+ ab = ab * scale.type_as(ab)
127
+ x, _ = blk(x, padding_mask, ab)
128
+
129
+ if self.norm is not None and self.layer_norm_first:
130
+ x = self.norm(x)
131
+
132
+ return x
133
+
134
+
135
+ class AltBlock(nn.Module):
136
+ def __init__(
137
+ self,
138
+ dim,
139
+ num_heads,
140
+ mlp_ratio=4.0,
141
+ qkv_bias=False,
142
+ qk_scale=None,
143
+ drop=0.0,
144
+ attn_drop=0.0,
145
+ mlp_drop=0.0,
146
+ post_mlp_drop=0.0,
147
+ drop_path=0.0,
148
+ act_layer=nn.GELU,
149
+ norm_layer=nn.LayerNorm,
150
+ layer_norm_first=True,
151
+ ffn_targets=False,
152
+ cosine_attention=False,
153
+ ):
154
+ super().__init__()
155
+
156
+ self.layer_norm_first = layer_norm_first
157
+ self.ffn_targets = ffn_targets
158
+
159
+ from timm.models.vision_transformer import DropPath, Mlp
160
+
161
+ self.norm1 = norm_layer(dim)
162
+ self.attn = AltAttention(
163
+ dim,
164
+ num_heads=num_heads,
165
+ qkv_bias=qkv_bias,
166
+ qk_scale=qk_scale,
167
+ attn_drop=attn_drop,
168
+ proj_drop=drop,
169
+ cosine_attention=cosine_attention,
170
+ )
171
+
172
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
173
+ self.norm2 = norm_layer(dim)
174
+ mlp_hidden_dim = int(dim * mlp_ratio)
175
+ self.mlp = Mlp(
176
+ in_features=dim,
177
+ hidden_features=mlp_hidden_dim,
178
+ act_layer=act_layer,
179
+ drop=mlp_drop,
180
+ )
181
+ self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
182
+
183
+ def forward(self, x, padding_mask=None, alibi_bias=None):
184
+ if self.layer_norm_first:
185
+ x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
186
+ r = x = self.mlp(self.norm2(x))
187
+ t = x
188
+ x = r + self.drop_path(self.post_mlp_dropout(x))
189
+ if not self.ffn_targets:
190
+ t = x
191
+ else:
192
+ x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
193
+ r = x = self.norm1(x)
194
+ x = self.mlp(x)
195
+ t = x
196
+ x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
197
+ if not self.ffn_targets:
198
+ t = x
199
+
200
+ return x, t
201
+
202
+
203
+ class AltAttention(nn.Module):
204
+ def __init__(
205
+ self,
206
+ dim,
207
+ num_heads=8,
208
+ qkv_bias=False,
209
+ qk_scale=None,
210
+ attn_drop=0.0,
211
+ proj_drop=0.0,
212
+ cosine_attention=False,
213
+ ):
214
+ super().__init__()
215
+ self.num_heads = num_heads
216
+ head_dim = dim // num_heads
217
+ self.scale = qk_scale or head_dim ** -0.5
218
+
219
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
220
+ self.attn_drop = nn.Dropout(attn_drop)
221
+ self.proj = nn.Linear(dim, dim)
222
+ self.proj_drop = nn.Dropout(proj_drop)
223
+
224
+ self.cosine_attention = cosine_attention
225
+
226
+ if cosine_attention:
227
+ self.logit_scale = nn.Parameter(
228
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
229
+ )
230
+
231
+ def forward(self, x, padding_mask=None, alibi_bias=None):
232
+ B, N, C = x.shape
233
+ qkv = (
234
+ self.qkv(x)
235
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
236
+ .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
237
+ )
238
+ q, k, v = (
239
+ qkv[0],
240
+ qkv[1],
241
+ qkv[2],
242
+ ) # make torchscript happy (cannot use tensor as tuple)
243
+
244
+ dtype = q.dtype
245
+
246
+ if self.cosine_attention:
247
+ # cosine attention
248
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
249
+ logit_scale = torch.clamp(
250
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
251
+ ).exp()
252
+ attn = attn * logit_scale
253
+ else:
254
+ q = q * self.scale
255
+ attn = q @ k.transpose(-2, -1)
256
+
257
+ if alibi_bias is not None:
258
+ attn = attn.type_as(alibi_bias)
259
+ attn[:, : alibi_bias.size(1)] += alibi_bias
260
+
261
+ if padding_mask is not None and padding_mask.any():
262
+ attn = attn.masked_fill(
263
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
264
+ float("-inf"),
265
+ )
266
+
267
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
268
+ attn = self.attn_drop(attn)
269
+ x = (attn @ v).transpose(1, 2) #
270
+ x = x.reshape(B, N, C)
271
+ x = self.proj(x)
272
+ x = self.proj_drop(x)
273
+ return x