|
import torch
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
|
|
from functools import partial
|
|
from einops import rearrange
|
|
from typing import Callable, Optional
|
|
from dataclasses import dataclass, field, is_dataclass
|
|
from transformers import PreTrainedModel
|
|
|
|
from .configuration_fisher import FISHERConfig
|
|
from .base import (
|
|
D2vModalityConfig,
|
|
ModalitySpecificEncoder,
|
|
)
|
|
from .modules import AltBlock
|
|
from .images import (
|
|
D2vImageConfig,
|
|
ImageEncoder,
|
|
)
|
|
|
|
|
|
@dataclass
class D2vModalitiesConfig:
    """Container for the per-modality encoder configurations.

    Only the ``image`` modality is declared here; each instance gets its own
    freshly constructed ``D2vImageConfig``.
    """

    # Built per instance so that no two configs share one mutable sub-config.
    image: D2vImageConfig = field(default_factory=D2vImageConfig)
|
|
|
|
|
|
@dataclass
class Data2VecMultiConfig:
    """Hyper-parameters consumed when building the FISHER model.

    Field order and defaults are part of the generated ``__init__`` signature
    and must not be reordered.
    """

    # Number of blocks in the shared trunk.
    depth: int = 12

    # Width of one band along the last axis of the input; the input is split
    # into chunks of this size before encoding.
    band_width: int = 100

    # Endpoints of the per-block drop-path schedule (linearly interpolated
    # over ``depth`` blocks).
    start_drop_path_rate: float = 0.0
    end_drop_path_rate: float = 0.0

    # Block / normalisation settings forwarded to the blocks and LayerNorm.
    num_heads: int = 12
    norm_eps: float = 1e-6
    norm_affine: bool = True

    # Dropout settings. Only ``post_mlp_drop``, ``attention_dropout`` and
    # ``activation_dropout`` are read by the model code visible in this file.
    encoder_dropout: float = 0.0
    post_mlp_drop: float = 0.0
    attention_dropout: float = 0.0
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0

    embed_dim: int = 768
    mlp_ratio: float = 4.0
    layer_norm_first: bool = False

    # Blocks are built with ``ffn_targets=not end_of_block_targets``.
    end_of_block_targets: bool = False

    # Divided by the number of bands to obtain the ``clone_batch`` value
    # handed to the modality encoder.
    max_band_per_sample: int = 64

    # Target-normalisation switches; not read by the code visible in this file.
    layer_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_target_layer: bool = True
    instance_norm_targets: bool = False
    layer_norm_targets: bool = True

    # Per-modality sub-configuration, built fresh for every instance.
    modalities: D2vModalitiesConfig = field(default_factory=D2vModalitiesConfig)
|
|
|
|
|
|
def update_dataclass(instance, data_dict):
    """Overwrite attributes of ``instance`` in place from ``data_dict``.

    Keys that ``instance`` has no attribute for are ignored.  When the current
    attribute value is a dataclass and the incoming value is a ``dict``, the
    update recurses into that nested dataclass instead of replacing it.

    Args:
        instance: Object (normally a dataclass instance) to mutate.
        data_dict: Mapping of attribute names to new values; a falsy value
            (``None`` or empty) leaves ``instance`` untouched.

    Returns:
        The same ``instance`` object that was passed in.
    """
    if data_dict:
        for key, incoming in data_dict.items():
            # Unknown keys are skipped rather than added as new attributes.
            if not hasattr(instance, key):
                continue

            existing = getattr(instance, key)
            descend = is_dataclass(existing) and isinstance(incoming, dict)
            if descend:
                # Merge into the nested dataclass so unspecified fields keep
                # their current values.
                update_dataclass(existing, incoming)
                continue

            setattr(instance, key, incoming)

    return instance
|
|
|
|
|
|
class FISHER(nn.Module):
    """Band-wise encoder: an image modality encoder followed by a block stack.

    The input is cut along its last axis into ``band_width``-wide bands, every
    band is encoded independently (bands are folded into the batch axis), and
    ``extract_features`` concatenates the per-band embeddings again.
    """

    def __init__(self, config: "FISHERConfig"):
        """Build the model from a ``FISHERConfig``.

        Args:
            config: Configuration object; its ``to_dict()`` output is merged
                onto a default ``Data2VecMultiConfig``.
        """
        super().__init__()

        cfg = Data2VecMultiConfig()
        update_dataclass(cfg, config.to_dict())
        # Keep the image modality's embed_dim in sync with the trunk's.
        cfg.modalities.image.embed_dim = cfg.embed_dim
        self.cfg = cfg

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            """Create one ``AltBlock``; ``dim``/``heads`` default to the cfg."""
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=0.0,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        # Plain dict handed to the modality encoder (not a registered buffer).
        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        mod_cfg = cfg.modalities.image
        enc = self.make_modality_encoder(
            mod_cfg,
            cfg.embed_dim,
            make_block,
            make_layer_norm,
            cfg.layer_norm_first,
            self.alibi_biases,
        )
        self.modality_encoders['IMAGE'] = enc

        # Drop-path rate grows linearly from the first to the last block.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(rate) for rate in dpr])

        # A final norm is only applied when blocks are built pre-norm.
        self.norm = None
        if cfg.layer_norm_first:
            self.norm = make_layer_norm(cfg.embed_dim)

        self.band_width = cfg.band_width
        self.patch_size = cfg.modalities.image.patch_size

    def make_modality_encoder(
        self,
        cfg: "D2vModalityConfig",
        embed_dim: int,
        make_block: Callable[[float], nn.Module],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task=None,
    ) -> "ModalitySpecificEncoder":
        """Construct the ``ImageEncoder`` for the given modality config.

        All arguments are forwarded positionally to ``ImageEncoder`` in the
        order they are declared here.
        """
        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def forward(
        self,
        source: torch.Tensor,
        target=None,
        id=None,
        mode='IMAGE',
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        force_remove_masked=False,
        precomputed_mask: Optional[torch.Tensor] = None,
    ) -> Optional[dict]:
        """Encode ``source`` band by band.

        Args:
            source: 4-D tensor; the last axis is split into bands of
                ``band_width``. A trailing remainder narrower than one band
                is discarded.
            target: Unused by this method.
            id: Unused by this method.
            mode: Key into ``self.modality_encoders``.
            padding_mask: Forwarded to the modality encoder.
            mask: Forwarded to the modality encoder.
            features_only: When True, per-block results are collected and a
                result dict is returned.
            force_remove_masked: Forces ``remove_masked=True`` for the
                modality encoder even when ``features_only`` is set.
            precomputed_mask: Forwarded to the modality encoder.

        Returns:
            When ``features_only`` is True, a dict with keys ``"x"``,
            ``"padding_mask"``, ``"layer_results"`` and ``"mask"``;
            otherwise ``None``.

        Raises:
            ValueError: If the last axis of ``source`` is narrower than
                ``band_width``, so that not even one band can be formed.
        """
        width = source.shape[-1]
        num_band = width // self.band_width
        if num_band < 1:
            # Fail early with a clear message instead of letting torch.stack
            # choke on an empty list of bands.
            raise ValueError(
                f"source last-dimension size {width} is smaller than "
                f"band_width={self.band_width}; at least one full band is required"
            )

        # (B, c, t, F) -> nb tensors of (B, c, t, band_width) -> (B * nb, c, t, f)
        bands = source.split(self.band_width, dim=-1)[:num_band]
        source = torch.stack(bands)
        source = rearrange(source, 'nb B c t f -> (B nb) c t f')
        clone_batch = self.cfg.max_band_per_sample // num_band

        feature_extractor = self.modality_encoders[mode]
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=clone_batch if not features_only else 1,
            mask_seeds=None,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        layer_results = []
        for i, blk in enumerate(self.blocks):
            ab = masked_alibi_bias
            if ab is not None and alibi_scale is not None:
                # One scale per block when the leading dim is > 1, otherwise
                # a single scale shared by every block.
                if alibi_scale.size(0) > 1:
                    scale = alibi_scale[i]
                else:
                    scale = alibi_scale.squeeze(0)
                ab = ab * scale.type_as(ab)

            x, lr = blk(
                x,
                padding_mask=masked_padding_mask,
                alibi_bias=ab,
            )
            if features_only:
                layer_results.append(lr)

        if self.norm is not None:
            x = self.norm(x)

        if features_only:
            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

        # No output is defined in this class for the non-feature path.
        return None

    def extract_features(
        self, source, mode='IMAGE', padding_mask=None, mask=False
    ):
        """Return one embedding per sample, concatenated over bands.

        Runs ``forward`` with ``features_only=True``, keeps index 0 along the
        second axis of the output (presumably a class/summary token — confirm
        against ``ImageEncoder``), and joins the per-band vectors.

        Returns:
            Tensor of shape ``(B, num_band * D)``.
        """
        num_band = source.shape[-1] // self.band_width
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
        )
        x = res['x'][:, 0]
        # Undo the band-into-batch folding done in forward().
        x = rearrange(x, '(B nb) D -> B (nb D)', nb=num_band)
        return x
|
|
|
|
|
|
class FISHERModel(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper that owns a ``FISHER`` module."""

    config_class = FISHERConfig

    def __init__(self, cfg: FISHERConfig):
        """Create the wrapped ``FISHER`` network from ``cfg``."""
        super().__init__(cfg)
        self.cfg = cfg
        self.model = FISHER(cfg)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped ``FISHER`` module with the same arguments."""
        return self.model(*args, **kwargs)

    def extract_features(self, x, mode='IMAGE', padding_mask=None, mask=False):
        """Delegate to ``FISHER.extract_features``.

        Args:
            x: Input tensor handed to the wrapped model as ``source``.
            mode: Modality key; defaults to the wrapped model's default.
            padding_mask: Optional padding mask forwarded unchanged.
            mask: Whether masking is applied; forwarded unchanged.

        Returns:
            Whatever ``FISHER.extract_features`` returns for these inputs.
        """
        # Defaults mirror FISHER.extract_features, so the one-argument call
        # behaves exactly as before.
        return self.model.extract_features(
            x, mode=mode, padding_mask=padding_mask, mask=mask
        )
|
|
|