# FISHER-mini-0723 / modeling_fisher.py
# Uploaded by jiangab via huggingface_hub (commit 8960e0d, verified).
import torch
import numpy as np
import torch.nn as nn
from functools import partial
from einops import rearrange
from typing import Callable, Optional
from dataclasses import dataclass, field, is_dataclass
from transformers import PreTrainedModel
from .configuration_fisher import FISHERConfig
from .base import (
D2vModalityConfig,
ModalitySpecificEncoder,
)
from .modules import AltBlock
from .images import (
D2vImageConfig,
ImageEncoder,
)
@dataclass
class D2vModalitiesConfig:
    """Container for per-modality encoder configs (only the image/spectrogram
    modality is used by FISHER)."""

    # The class itself is a valid zero-argument factory; no lambda needed.
    image: D2vImageConfig = field(default_factory=D2vImageConfig)
@dataclass
class Data2VecMultiConfig:
    """Internal configuration for the FISHER backbone (data2vec-multi style).

    Instances are created with defaults and then overwritten from the HF
    ``FISHERConfig`` via ``update_dataclass``.
    """

    # Number of Transformer blocks in the shared stack.
    depth: int = 12

    # band split: width (in frequency bins) of each sub-band.
    band_width: int = 100

    # standard vision Transformer
    start_drop_path_rate: float = 0.0  # stochastic depth, first block
    end_drop_path_rate: float = 0.0    # stochastic depth, last block
    num_heads: int = 12
    norm_eps: float = 1e-6
    norm_affine: bool = True
    encoder_dropout: float = 0.0
    post_mlp_drop: float = 0.0
    attention_dropout: float = 0.0
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0
    embed_dim: int = 768
    mlp_ratio: float = 4.0
    layer_norm_first: bool = False
    end_of_block_targets: bool = False

    # clone batch for multi-mask strategy
    max_band_per_sample: int = 64

    # normalization for teacher Transformer layer output
    layer_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_target_layer: bool = True
    instance_norm_targets: bool = False
    layer_norm_targets: bool = True

    # The class itself is a valid zero-argument factory; no lambda needed.
    modalities: D2vModalitiesConfig = field(default_factory=D2vModalitiesConfig)
def update_dataclass(instance, data_dict):
    """Recursively overwrite fields of *instance* with values from *data_dict*.

    Keys that *instance* does not already define are ignored.  When the
    existing attribute is itself a dataclass and the incoming value is a
    dict, the update recurses into the nested object instead of replacing
    it wholesale.  The (mutated) *instance* is returned for convenience.
    """
    if not data_dict:
        return instance
    for name, value in data_dict.items():
        if not hasattr(instance, name):
            continue  # silently skip unknown keys
        current = getattr(instance, name)
        if is_dataclass(current) and isinstance(value, dict):
            update_dataclass(current, value)
        else:
            setattr(instance, name, value)
    return instance
class FISHER(nn.Module):
    """Band-split spectrogram Transformer encoder (data2vec-multi style).

    The input spectrogram is cut along its last (frequency) axis into
    fixed-width sub-bands; each sub-band is patch-embedded by the image
    modality encoder and then processed by a shared Transformer stack.
    """

    def __init__(self, config: FISHERConfig):
        super().__init__()
        # Map the HF config onto the internal dataclass config, keeping the
        # dataclass defaults for anything the HF config does not provide.
        cfg = Data2VecMultiConfig()
        update_dataclass(cfg, config.to_dict())
        # Keep the modality encoder's width in sync with the backbone width.
        # (The original had this assignment duplicated; once is enough.)
        cfg.modalities.image.embed_dim = cfg.embed_dim
        self.cfg = cfg

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            # Transformer-block factory; dim/heads default to the backbone's.
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=0.0,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        mod_cfg = getattr(cfg.modalities, 'image')
        enc = self.make_modality_encoder(
            mod_cfg,
            cfg.embed_dim,
            make_block,
            make_layer_norm,
            cfg.layer_norm_first,
            self.alibi_biases,
        )
        self.modality_encoders['IMAGE'] = enc

        # Stochastic-depth rate grows linearly from start to end of the stack.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        # Final norm only when pre-norm ("layer_norm_first") blocks are used.
        self.norm = None
        if cfg.layer_norm_first:
            self.norm = make_layer_norm(cfg.embed_dim)

        # band split
        self.band_width = cfg.band_width
        self.patch_size = cfg.modalities.image.patch_size

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task=None,
    ) -> ModalitySpecificEncoder:
        """Build the (image) modality-specific patch encoder."""
        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def forward(
        self,
        source: torch.Tensor,
        target=None,
        id=None,
        mode='IMAGE',
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        force_remove_masked=False,
        precomputed_mask: Optional[torch.Tensor] = None,
    ):
        """Encode a spectrogram batch.

        ``source`` is split along its last axis into ``band_width``-wide
        sub-bands (assumes layout (B, c, t, f) — TODO confirm against the
        upstream feature extractor); any residual narrower than one band is
        dropped.  Returns a dict of features when ``features_only`` is True;
        the pretraining path (``features_only=False``) has no return value in
        this export — presumably the teacher/loss code was stripped.
        """
        # band split: (B, c, t, f) -> (B * num_band, c, t, band_width)
        num_band = source.shape[-1] // self.band_width
        source = torch.stack(source.split(self.band_width, dim=-1)[:num_band])  # drop residual
        source = rearrange(source, 'nb B c t f -> (B nb) c t f')
        # Budget the multi-mask clones so each sample contributes at most
        # max_band_per_sample masked sub-band views.
        clone_batch = self.cfg.max_band_per_sample // num_band

        feature_extractor = self.modality_encoders[mode]  # models.images.ImageEncoder

        # extract (unmasked) features using CNN encoder
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,  # train: True; infer: False
            clone_batch=clone_batch if not features_only else 1,
            mask_seeds=None,
            precomputed_mask=precomputed_mask,
        )

        # x in shape (batch_size * clone batch, patch_frame(64) * patch_freqency(8) * unmask_ratio(0.2) + 1(cls_token), 768(feature dimension))
        x = extractor_out["x"]
        # encoder_mask is applied on sub-band level
        encoder_mask = extractor_out["encoder_mask"]  # models.base.MaskInfo, ["x_unmasked", "mask", "ids_restore", "ids_keep"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        # standard Transformer (for student encoder)
        layer_results = []
        for i, blk in enumerate(self.blocks):
            ab = masked_alibi_bias
            if ab is not None and alibi_scale is not None:
                # Per-layer alibi scale when one scale is provided per block.
                scale = (
                    alibi_scale[i]
                    if alibi_scale.size(0) > 1
                    else alibi_scale.squeeze(0)
                )
                ab = ab * scale.type_as(ab)
            x, lr = blk(
                x,
                padding_mask=masked_padding_mask,
                alibi_bias=ab,
            )
            if features_only:
                layer_results.append(lr)

        if self.norm is not None:
            x = self.norm(x)

        # extract features for fine-tuning
        if features_only:
            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

    def extract_features(
        self, source, mode='IMAGE', padding_mask=None, mask=False
    ):
        """Return per-sample embeddings: the CLS token of every sub-band,
        concatenated along the feature axis -> (B, num_band * embed_dim)."""
        num_band = source.shape[-1] // self.band_width
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
        )
        x = res['x'][:, 0]  # CLS token of each (sample, band) pair
        x = rearrange(x, '(B nb) D -> B (nb D)', nb=num_band)
        return x
class FISHERModel(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around the FISHER encoder."""

    config_class = FISHERConfig

    def __init__(self, cfg: FISHERConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.model = FISHER(cfg)

    def forward(self, *args, **kwargs):
        # Thin pass-through to the underlying encoder.
        return self.model(*args, **kwargs)

    def extract_features(self, x, **kwargs):
        # Forward optional keyword arguments (mode, padding_mask, mask) to the
        # encoder instead of silently dropping them; callers that passed only
        # ``x`` behave exactly as before.
        return self.model.extract_features(x, **kwargs)