import math
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import fvcore.nn.weight_init as weight_init
from detectron2.layers import Conv2d

from modeling.semantic_enhanced_matting.modeling.image_encoder import Attention
from modeling.semantic_enhanced_matting.modeling.transformer import Attention as DownAttention
from modeling.semantic_enhanced_matting.feature_fusion import PositionEmbeddingRandom as ImagePositionEmbedding
from modeling.semantic_enhanced_matting.modeling.common import MLPBlock


class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension for tensors in (B, C, H, W) layout."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


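# Sanity-check sketch: LayerNorm2d normalizes the channel vector at every spatial
# position, so it should match F.layer_norm applied channels-last. The shapes
# below are illustrative only.
def _check_layer_norm_2d():
    ln = LayerNorm2d(8)
    x = torch.randn(2, 8, 4, 4)
    ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, eps=ln.eps)
    assert torch.allclose(ln(x), ref.permute(0, 3, 1, 2), atol=1e-5)

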
class ConditionConv(nn.Module):
    """
    The standard bottleneck residual block without the last activation layer,
    with a per-sample condition vector added to the middle (3x3) branch.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm=LayerNorm2d,
        act_layer=nn.GELU,
        conv_kernels=3,
        conv_paddings=1,
        condition_channels=1024,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of channels for the middle
                "bottleneck" conv layer.
            norm (callable): normalization layer constructor applied after each conv.
            act_layer (callable): activation for all conv layers.
            conv_kernels (int): kernel size of the middle conv layer.
            conv_paddings (int): padding of the middle conv layer.
            condition_channels (int): dimension of the input condition vector.
        """
        super().__init__()

        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = norm(bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            conv_kernels,
            padding=conv_paddings,
            bias=False,
        )
        self.norm2 = norm(bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = norm(out_channels)

        self.init_weight()

        self.condition_embedding = nn.Sequential(
            act_layer(),
            nn.Linear(condition_channels, bottleneck_channels, bias=True)
        )

    def init_weight(self):
        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()

        # Zero-init the last norm so the residual branch starts as identity.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x, condition):
        # x: (B, H, W, C) channels-last block feature; condition: (B, condition_channels)
        out = x.permute(0, 3, 1, 2)

        out = self.act1(self.norm1(self.conv1(out)))
        out = self.conv2(out) + self.condition_embedding(condition)[:, :, None, None]
        out = self.act2(self.norm2(out))
        out = self.norm3(self.conv3(out))

        out = x + out.permute(0, 2, 3, 1)
        return out


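# Illustrative shape sketch for ConditionConv: channels-last features plus a
# per-sample condition vector give a residual output of the same shape. The
# 16x16 spatial size and channel counts are made up for the example.
def _demo_condition_conv():
    block = ConditionConv(in_channels=1024, out_channels=1024, bottleneck_channels=256)
    feat = torch.randn(2, 16, 16, 1024)   # (B, H, W, C) block feature
    cond = torch.randn(2, 1024)           # (B, condition_channels) condition vector
    out = block(feat, cond)               # same shape as feat; identity at init (norm3 is zero-init)
    print(out.shape)                      # torch.Size([2, 16, 16, 1024])

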
class ConditionAdd(nn.Module):
    """Project the condition vector and add it to the block feature as a per-sample bias."""

    def __init__(
        self,
        act_layer=nn.GELU,
        condition_channels=1024,
    ):
        super().__init__()

        self.condition_embedding = nn.Sequential(
            act_layer(),
            nn.Linear(condition_channels, condition_channels, bias=True)
        )

    def forward(self, x, condition):
        # x: (B, H, W, C); condition: (B, condition_channels)
        condition = self.condition_embedding(condition)[:, None, None, :]
        return x + condition


class ConditionEmbedding(nn.Module):
    """
    Sinusoidal embedding of a small set of scalar conditions (e.g. bbox coordinates),
    followed by a two-layer MLP projection, similar to diffusion timestep embeddings.
    """

    def __init__(
        self,
        condition_num=5,
        pos_embedding_dim=128,
        embedding_scale=1.0,
        embedding_max_period=10000,
        embedding_flip_sin_to_cos=True,
        embedding_downscale_freq_shift=1.0,
        time_embed_dim=1024,
        split_embed=False,
    ):
        super().__init__()
        self.condition_num = condition_num
        self.pos_embedding_dim = pos_embedding_dim
        self.embedding_scale = embedding_scale
        self.embedding_max_period = embedding_max_period
        self.embedding_flip_sin_to_cos = embedding_flip_sin_to_cos
        self.embedding_downscale_freq_shift = embedding_downscale_freq_shift
        self.split_embed = split_embed

        if self.split_embed:
            self.linear_1 = nn.Linear(pos_embedding_dim, time_embed_dim, True)
        else:
            self.linear_1 = nn.Linear(condition_num * pos_embedding_dim, time_embed_dim, True)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, True)

    def proj_embedding(self, condition):
        sample = self.linear_1(condition)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample

    def position_embedding(self, condition):
        # condition: (B, condition_num) scalar conditions
        assert condition.shape[-1] == self.condition_num

        half_dim = self.pos_embedding_dim // 2
        exponent = -math.log(self.embedding_max_period) * torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=condition.device
        )
        exponent = exponent / (half_dim - self.embedding_downscale_freq_shift)

        emb = torch.exp(exponent)
        emb = condition[:, :, None].float() * emb[None, None, :]

        emb = self.embedding_scale * emb

        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

        if self.embedding_flip_sin_to_cos:
            emb = torch.cat([emb[:, :, half_dim:], emb[:, :, :half_dim]], dim=-1)

        if self.split_embed:
            # One embedding per condition: (B * condition_num, pos_embedding_dim)
            emb = emb.reshape(-1, emb.shape[-1])
        else:
            # Concatenate all conditions: (B, condition_num * pos_embedding_dim)
            emb = emb.reshape(emb.shape[0], -1)

        return emb

    def forward(self, condition):
        condition = self.position_embedding(condition)
        condition = self.proj_embedding(condition)
        return condition.float()


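# Illustrative sketch: a (left, up, right, down) box embedded as four sinusoidal
# codes, then projected to time_embed_dim. The box values are arbitrary.
def _demo_condition_embedding():
    bbox = torch.tensor([[100.0, 200.0, 300.0, 400.0]])                  # (B=1, condition_num=4)
    joint = ConditionEmbedding(condition_num=4, pos_embedding_dim=256)
    print(joint(bbox).shape)                                             # torch.Size([1, 1024])
    split = ConditionEmbedding(condition_num=4, pos_embedding_dim=256, split_embed=True)
    print(split(bbox).shape)                                             # torch.Size([4, 1024])

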
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies, with learnable
    embeddings added to the two box corner points.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0

        self.positional_encoding_gaussian_matrix = nn.Parameter(scale * torch.randn((2, num_pos_feats // 2)))

        # Learnable offsets for the top-left and bottom-right corner tokens.
        point_embeddings = [nn.Embedding(1, num_pos_feats) for _ in range(2)]
        self.point_embeddings = nn.ModuleList(point_embeddings)

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        # coords_input: (B, 2, 2) float pixel coordinates; image_size: (H, W).
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        coords = self._pe_encoding(coords.to(torch.float))

        coords[:, 0, :] += self.point_embeddings[0].weight
        coords[:, 1, :] += self.point_embeddings[1].weight

        return coords


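# Illustrative sketch: the two box corners, given in pixels of a 1024x1024 image,
# each map to one num_pos_feats-dimensional token. The coordinates are arbitrary.
def _demo_position_embedding_random():
    pe = PositionEmbeddingRandom(num_pos_feats=1024)
    corners = torch.tensor([[[100.0, 200.0], [300.0, 400.0]]])  # (B, 2, 2): (x, y) top-left / bottom-right
    print(pe(corners, (1024, 1024)).shape)                      # torch.Size([1, 2, 1024])

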
class CrossSelfAttn(nn.Module):
    """
    Cross-attention from image features to bbox tokens, followed by an MLP and a
    joint self-attention over the concatenated feature and bbox tokens.
    """

    def __init__(self, embedding_dim=1024, num_heads=4, downsample_rate=4) -> None:
        super().__init__()

        self.cross_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.mlp = MLPBlock(embedding_dim, mlp_dim=512)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.self_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
        self.norm3 = nn.LayerNorm(embedding_dim)

    def forward(self, block_feat, bbox_token, feat_pe, bbox_pe):
        # block_feat: (B, H, W, C); bbox_token: (B, N, C)
        # feat_pe: (1 or B, H * W, C); bbox_pe: (1 or B, N, C)
        B, H, W, C = block_feat.shape
        block_feat = block_feat.reshape(B, H * W, C)

        block_feat = block_feat + self.cross_attn(q=block_feat + feat_pe, k=bbox_token + bbox_pe, v=bbox_token)
        block_feat = self.norm1(block_feat)

        block_feat = block_feat + self.mlp(block_feat)
        block_feat = self.norm2(block_feat)

        # Joint self-attention over image and bbox tokens; keep only the image part.
        concat_token = torch.concat((block_feat + feat_pe, bbox_token + bbox_pe), dim=1)
        block_feat = block_feat + self.self_attn(q=concat_token, k=concat_token, v=concat_token)[:, :-bbox_token.shape[1]]
        block_feat = self.norm3(block_feat)
        output = block_feat.reshape(B, H, W, C)

        return output


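# Illustrative shape sketch, assuming the repo's DownAttention and MLPBlock keep
# token shape (as the SAM modules they are imported from do). All sizes are made up.
def _demo_cross_self_attn():
    layer = CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4)
    feat = torch.randn(2, 16, 16, 1024)        # (B, H, W, C) image feature
    bbox_token = torch.randn(2, 2, 1024)       # (B, N, C) bbox tokens
    feat_pe = torch.randn(1, 16 * 16, 1024)    # positional encoding for the image tokens
    bbox_pe = torch.randn(1, 2, 1024)          # positional encoding for the bbox tokens
    out = layer(feat, bbox_token, feat_pe=feat_pe, bbox_pe=bbox_pe)
    print(out.shape)                           # torch.Size([2, 16, 16, 1024])

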
class BBoxEmbedInteract(nn.Module):
    """
    Inject bbox information into ViT block features, with several choices for how
    the box is embedded ('fourier', 'position', 'conv') and how it interacts with
    the features ('add', 'attn', 'cross-self-attn').
    """

    def __init__(
        self,
        embed_type='fourier',
        interact_type='attn',
        layer_num=3
    ):
        super().__init__()
        assert embed_type in {'fourier', 'position', 'conv'}
        assert interact_type in {'add', 'attn', 'cross-self-attn'}
        self.embed_type = embed_type
        self.interact_type = interact_type
        self.layer_num = layer_num

        if self.embed_type == 'fourier' and self.interact_type == 'add':
            self.embed_layer = ConditionEmbedding(condition_num=4, pos_embedding_dim=256)
        elif self.embed_type == 'fourier':
            self.embed_layer = ConditionEmbedding(condition_num=4, pos_embedding_dim=256, split_embed=True)
        elif self.embed_type == 'conv':
            mask_in_chans = 16
            activation = nn.GELU
            self.embed_layer = nn.Sequential(
                nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
                LayerNorm2d(mask_in_chans // 4),
                activation(),
                nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
                LayerNorm2d(mask_in_chans),
                activation(),
                nn.Conv2d(mask_in_chans, 1024, kernel_size=1),
            )
        else:
            if self.interact_type == 'add':
                self.embed_layer = PositionEmbeddingRandom(num_pos_feats=512)
            else:
                self.embed_layer = PositionEmbeddingRandom(num_pos_feats=1024)

        self.interact_layer = nn.ModuleList()
        for _ in range(self.layer_num):
            if self.interact_type == 'attn':
                self.interact_layer.append(Attention(dim=1024))
            elif self.interact_type == 'add' and self.embed_type != 'conv':
                self.interact_layer.append(nn.Sequential(
                    nn.GELU(),
                    nn.Linear(1024, 1024, bias=True)
                ))
            elif self.interact_type == 'cross-self-attn':
                self.interact_layer.append(CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4))

        self.position_layer = ImagePositionEmbedding(num_pos_feats=1024 // 2)

    def forward(self, block_feat, bbox, layer_index):
        # block_feat: (B, 64, 64, 1024); bbox: (B, 1, 4) as (l, u, r, d) in pixels of a
        # 1024x1024 input image; layer_index selects which interaction layer to apply.
        if layer_index == self.layer_num:
            return block_feat
        interact_layer = self.interact_layer[layer_index]

        bbox = bbox + 0.5
        if self.embed_type == 'fourier' and self.interact_type == 'add':
            embedding = self.embed_layer(bbox[:, 0])
            embedding = embedding.reshape(embedding.shape[0], 1, -1)
        elif self.embed_type == 'fourier':
            embedding = self.embed_layer(bbox[:, 0])
            embedding = embedding.reshape(-1, 4, embedding.shape[-1])
        elif self.embed_type == 'conv':
            # Rasterize the box into a 256x256 binary mask (1024-pixel coords / 4).
            bbox_mask = torch.zeros(size=(block_feat.shape[0], 1, 256, 256), device=block_feat.device, dtype=block_feat.dtype)
            for i in range(bbox.shape[0]):
                l, u, r, d = bbox[i, 0, :] / 4
                bbox_mask[i, :, int(u + 0.5): int(d + 0.5), int(l + 0.5): int(r + 0.5)] = 1.0
            embedding = self.embed_layer(bbox_mask)
        elif self.embed_type == 'position':
            embedding = self.embed_layer(bbox.reshape(-1, 2, 2), (1024, 1024))
            if self.interact_type == 'add':
                embedding = embedding.reshape(embedding.shape[0], 1, -1)

        pe = self.position_layer(size=(64, 64)).reshape(1, 64, 64, 1024)
        block_feat = block_feat + pe

        if self.interact_type == 'attn':
            add_token_num = embedding.shape[1]
            B, H, W, C = block_feat.shape
            block_feat = block_feat.reshape(B, H * W, C)
            concat_token = torch.concat((block_feat, embedding), dim=1)
            output_token = interact_layer.forward_token(concat_token)[:, :-add_token_num]
            output = output_token.reshape(B, H, W, C)
        elif self.embed_type == 'conv':
            output = block_feat + embedding.permute(0, 2, 3, 1)
        elif self.interact_type == 'add':
            output = interact_layer(embedding[:, None]) + block_feat
        elif self.interact_type == 'cross-self-attn':
            # CrossSelfAttn needs explicit positional encodings; the bbox embedding
            # doubles as its own, mirroring BBoxInteract below.
            output = interact_layer(block_feat, embedding, feat_pe=pe.reshape(1, -1, pe.shape[-1]), bbox_pe=embedding)

        return output


class BBoxInteract(nn.Module):
    """
    Cross/self-attention interaction between block features and bbox corner tokens
    produced by an externally provided point positional embedding and point-type
    weights (both frozen here).
    """

    def __init__(
        self,
        position_point_embedding,
        point_weight,
        layer_num=3,
    ):
        super().__init__()

        self.position_point_embedding = position_point_embedding
        self.point_weight = point_weight
        # Freeze the borrowed prompt-encoder parameters.
        for _, p in self.named_parameters():
            p.requires_grad = False

        self.layer_num = layer_num
        self.input_image_size = (1024, 1024)

        self.interact_layer = nn.ModuleList()
        for _ in range(self.layer_num):
            self.interact_layer.append(CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4))

    @torch.no_grad()
    def get_bbox_token(self, boxes):
        boxes = boxes + 0.5
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.position_point_embedding.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_weight[2].weight
        corner_embedding[:, 1, :] += self.point_weight[3].weight
        # Resize the corner tokens to the 1024-d block feature width.
        corner_embedding = F.interpolate(corner_embedding[..., None], size=(1024, 1), mode='bilinear', align_corners=False)[..., 0]
        return corner_embedding

    @torch.no_grad()
    def get_position_embedding(self, size=(64, 64)):
        pe = self.position_point_embedding(size=size)
        pe = F.interpolate(pe.permute(1, 2, 0)[..., None], size=(1024, 1), mode='bilinear', align_corners=False)[..., 0][None]
        pe = pe.reshape(1, -1, 1024)
        return pe

    def forward(self, block_feat, bbox, layer_index):
        if layer_index == self.layer_num:
            return block_feat
        interact_layer = self.interact_layer[layer_index]

        pe = self.get_position_embedding()
        bbox_token = self.get_bbox_token(bbox)

        # The bbox tokens serve as their own positional encoding.
        output = interact_layer(block_feat, bbox_token, feat_pe=pe, bbox_pe=bbox_token)

        return output


class InOutBBoxCrossSelfAttn(nn.Module):
    """
    Pre-norm block: self-attention over the tokens inside the box, an MLP, then
    cross-attention from the inside-box tokens to the outside-box tokens.
    """

    def __init__(self, embedding_dim=1024, num_heads=4, downsample_rate=4) -> None:
        super().__init__()

        self.self_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.mlp = MLPBlock(embedding_dim, mlp_dim=embedding_dim // 2)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.cross_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
        self.norm3 = nn.LayerNorm(embedding_dim)

    def forward(self, in_box_token, out_box_token):
        # in_box_token: (B, N_in, C); out_box_token: (B, N_out, C)
        short_cut = in_box_token
        in_box_token = self.norm1(in_box_token)
        in_box_token = self.self_attn(q=in_box_token, k=in_box_token, v=in_box_token)
        in_box_token = short_cut + in_box_token

        in_box_token = in_box_token + self.mlp(self.norm2(in_box_token))

        short_cut = in_box_token
        in_box_token = self.norm3(in_box_token)
        in_box_token = self.cross_attn(q=in_box_token, k=out_box_token, v=out_box_token)
        in_box_token = short_cut + in_box_token

        return in_box_token


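# Illustrative shape sketch, assuming the repo's DownAttention and MLPBlock keep
# token shape; the token counts below are arbitrary.
def _demo_in_out_bbox_attn():
    layer = InOutBBoxCrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4)
    in_tokens = torch.randn(1, 50, 1024)       # tokens inside the box
    out_tokens = torch.randn(1, 206, 1024)     # tokens outside the box
    print(layer(in_tokens, out_tokens).shape)  # torch.Size([1, 50, 1024])

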
class BBoxInteractInOut(nn.Module):
    """Update the tokens inside each bbox by attending to the tokens outside it."""

    def __init__(
        self,
        num_heads=4,
        downsample_rate=4,
        layer_num=3,
    ):
        super().__init__()

        self.layer_num = layer_num
        self.input_image_size = (1024, 1024)

        self.interact_layer = nn.ModuleList()
        for _ in range(self.layer_num):
            self.interact_layer.append(InOutBBoxCrossSelfAttn(embedding_dim=1024, num_heads=num_heads, downsample_rate=downsample_rate))

    def forward(self, block_feat, bbox, layer_index):
        if layer_index == self.layer_num:
            return block_feat
        interact_layer = self.interact_layer[layer_index]

        # Map the bbox from input-image pixels to feature-grid coordinates.
        bbox = torch.round(bbox / self.input_image_size[0] * (block_feat.shape[1] - 1)).int()
        for i in range(block_feat.shape[0]):
            in_bbox_mask = torch.zeros((block_feat.shape[1], block_feat.shape[2]), dtype=torch.bool, device=bbox.device)
            in_bbox_mask[bbox[i, 0, 1]: bbox[i, 0, 3], bbox[i, 0, 0]: bbox[i, 0, 2]] = True
            in_bbox_token = block_feat[i: i + 1, in_bbox_mask, :]
            out_bbox_token = block_feat[i: i + 1, ~in_bbox_mask, :]
            # Write back with a leading batch dim so shapes match the (1, N, C) output.
            block_feat[i: i + 1, in_bbox_mask, :] = interact_layer(in_bbox_token, out_bbox_token)

        return block_feat


if __name__ == '__main__':

    embed_interact = BBoxEmbedInteract(
        embed_type='position',
        interact_type='cross-self-attn'
    )
    # One (l, u, r, d) box per sample, in pixels of a 1024x1024 image; float so the
    # coordinate normalization inside PositionEmbeddingRandom is not truncated.
    bbox = torch.tensor([[[100.0, 200.0, 300.0, 400.0]], [[100.0, 200.0, 300.0, 400.0]]])
    print(bbox.shape)
    output = embed_interact(torch.randn((2, 64, 64, 1024)), bbox, layer_index=0)
    print(output.shape)