|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
from monai.networks.blocks import Warp |
|
from monai.networks.nets import resnet18 |
|
from monai.networks.nets.regunet import AffineHead |
|
|
|
|
|
class RegResNet(nn.Module):
    """Affine registration network: backbone features -> ``AffineHead`` -> ``Warp``.

    The backbone (``mod`` or a freshly built ``resnet18``) consumes the full input.
    Its output is handed to ``AffineHead``, which produces the field named ``ddf``
    in the original code (sized by ``image_size``). ``Warp`` then resamples the
    first channel of the input with that field.

    Args:
        image_size: spatial size given to ``AffineHead`` at construction and again at
            every forward call.
        spatial_dims: number of spatial dimensions for both the backbone and the head.
        mod: optional custom backbone module. When ``None``, a ``resnet18`` with
            two input channels is built.
        mode: interpolation mode forwarded to ``Warp``.
        padding_mode: padding mode forwarded to ``Warp``.
        features: ``in_channels`` of ``AffineHead``. It must agree with the size of
            the backbone output. NOTE(review): 400 presumably matches the default
            output size of ``resnet18``; confirm if ``mod`` is supplied.
    """

    def __init__(
        self,
        image_size=(64, 64),
        spatial_dims=2,
        mod=None,
        mode="bilinear",
        padding_mode="border",
        features=400,
    ):
        super().__init__()
        # Build the default backbone only when no custom one was supplied.
        # Construction order is kept (backbone before head) so parameter
        # initialisation and registration order are unchanged.
        if mod is None:
            mod = resnet18(n_input_channels=2, spatial_dims=spatial_dims)
        self.features = mod
        # The head decodes from a 1-voxel-per-dimension feature map.
        self.affine_head = AffineHead(
            spatial_dims=spatial_dims,
            image_size=image_size,
            decode_size=[1] * spatial_dims,
            in_channels=features,
        )
        self.warp = Warp(mode=mode, padding_mode=padding_mode)
        self.image_size = image_size

    def forward(self, x):
        """Predict a transform from ``x`` and warp the first channel of ``x`` with it.

        Args:
            x: input batch. It is presumably (batch, channels, *spatial), with the
                image to be warped in channel 0. TODO confirm against callers.

        Returns:
            The output of ``Warp`` applied to ``x[:, :1]`` and the field produced
            by ``AffineHead``.
        """
        # Move the learnable submodules onto the input's device. Module.to works
        # in place, so this mutates the model on every call.
        for submodule in (self.features, self.affine_head):
            submodule.to(device=x.device)
        embedding = self.features(x)
        displacement = self.affine_head([embedding], self.image_size)
        return self.warp(x[:, :1], displacement)
|
|