|
import torch |
|
|
|
|
|
class TestTimeAugmentation:
    """Test-Time Augmentation (TTA) for image restoration models.

    Each call runs ``sliding_window`` on geometrically transformed copies of
    the inputs (identity, optional horizontal/vertical flips, optional
    90/180/270-degree rotations, optionally at several resolutions), undoes
    the geometric transform on every prediction, resizes it back to the input
    resolution when needed, and returns the unweighted mean of all
    predictions.
    """

    # Quarter turns passed as ``k`` to ``torch.rot90`` over dims [2, 3] for
    # each rotation augmentation.
    _ROT90_TURNS = {'rot90': 1, 'rot180': 2, 'rot270': 3}

    def __init__(self, model, dino_net, device, use_flip=True, use_rot=True,
                 use_multi_scale=False, scales=None):
        """
        Args:
            model: The model to apply TTA to; forwarded to ``sliding_window``.
            dino_net: DINO feature extractor; forwarded to ``sliding_window``.
            device: Device forwarded to ``sliding_window``.
            use_flip: Whether to add horizontal and vertical flips.
            use_rot: Whether to add 90/180/270-degree rotations.
            use_multi_scale: Whether to run inference at every scale in
                ``scales``. When False only the native resolution is used,
                even if ``scales`` was given.
            scales: Scales for multi-scale testing, e.g. ``[0.8, 1.0, 1.2]``.
                Defaults to ``[1.0]`` when None or empty.
        """
        self.model = model
        self.dino_net = dino_net
        self.device = device
        self.use_flip = use_flip
        self.use_rot = use_rot
        self.use_multi_scale = use_multi_scale
        self.scales = scales or [1.0]

    def _augmentation_types(self):
        """Return the enabled augmentation names, always starting with 'original'."""
        augmentations = ['original']
        if self.use_flip:
            augmentations += ['h_flip', 'v_flip']
        if self.use_rot:
            augmentations += ['rot90', 'rot180', 'rot270']
        return augmentations

    @classmethod
    def _spatial_transform(cls, tensor, aug_type, inverse=False):
        """Apply (or, with ``inverse=True``, undo) the pixel-grid transform of ``aug_type``.

        Only dims 2 and 3 of ``tensor`` are permuted; values are not changed.
        For 'original' the input tensor is returned unchanged.

        Raises:
            ValueError: if ``aug_type`` is not a known augmentation.
        """
        if aug_type == 'original':
            return tensor
        if aug_type == 'h_flip':
            return torch.flip(tensor, dims=[3])  # flips are their own inverse
        if aug_type == 'v_flip':
            return torch.flip(tensor, dims=[2])
        turns = cls._ROT90_TURNS.get(aug_type)
        if turns is None:
            raise ValueError(f"Unknown augmentation type: {aug_type}")
        if inverse:
            turns = 4 - turns  # complete the full turn to get back to the start
        return torch.rot90(tensor, k=turns, dims=[2, 3])

    @staticmethod
    def _transform_normal(normal, aug_type):
        """Re-express the x/y components of an already spatially transformed normal map.

        Channel 0 is treated as x and channel 1 as y; channels 2+ are copied
        through. Returns a new tensor (or ``normal`` itself for 'original').

        The sign pattern (e.g. rot90: (x, y) -> (-y, x)) matches a y-up
        convention for the counter-clockwise (as displayed) ``torch.rot90``.
        NOTE(review): confirm the normal maps use y-up; for y-down
        (OpenCV-style) normals the rot90/rot270 formulas would be swapped.

        Raises:
            ValueError: if ``aug_type`` is not a known augmentation.
        """
        if aug_type == 'original':
            return normal
        x, y = normal[:, 0:1], normal[:, 1:2]
        if aug_type == 'h_flip':
            x = -x
        elif aug_type == 'v_flip':
            y = -y
        elif aug_type == 'rot90':
            x, y = -y, x
        elif aug_type == 'rot180':
            x, y = -x, -y
        elif aug_type == 'rot270':
            x, y = y, -x
        else:
            raise ValueError(f"Unknown augmentation type: {aug_type}")
        return torch.cat([x, y, normal[:, 2:]], dim=1)

    def _apply_augmentation(self, image, point, normal, aug_type):
        """Apply a single augmentation to the input images.

        Args:
            image: Input RGB image, assumed [B, C, H, W].
            point: Point map; it is moved with the pixel grid, but its values
                are left unchanged. NOTE(review): if it stores camera-space
                XYZ, flips/rotations arguably change the meaning of X/Y as
                they do for the normals — confirm against training.
            normal: Normal map; moved with the pixel grid and its x/y
                components are re-expressed (see ``_transform_normal``).
            aug_type: One of 'original', 'h_flip', 'v_flip', 'rot90',
                'rot180', 'rot270'.

        Returns:
            Augmented (image, point, normal). The input tensors are not
            modified in place.

        Raises:
            ValueError: if ``aug_type`` is unknown.
        """
        img_aug = self._spatial_transform(image, aug_type)
        point_aug = self._spatial_transform(point, aug_type)
        normal_aug = self._transform_normal(
            self._spatial_transform(normal, aug_type), aug_type
        )
        return img_aug, point_aug, normal_aug

    def _reverse_augmentation(self, result, aug_type):
        """Undo the pixel-grid transform of ``aug_type`` on a model output.

        Only the spatial layout is reverted; output values are not modified.

        Args:
            result: Model output, assumed [B, C, H', W'].
            aug_type: Augmentation type string used to produce the input.

        Returns:
            De-augmented result.

        Raises:
            ValueError: if ``aug_type`` is unknown.
        """
        return self._spatial_transform(result, aug_type, inverse=True)

    def __call__(self, sliding_window, input_img, point, normal):
        """Apply TTA to the model and return the ensemble result.

        Args:
            sliding_window: Callable invoked as ``sliding_window(model=...,
                input_=..., point=..., normal=..., dino_net=..., device=...)``
                that returns a prediction for the given (augmented) inputs.
            input_img: Input RGB image [B, C, H, W].
            point: Point map [B, C, H, W].
            normal: Normal map [B, C, H, W].

        Returns:
            Mean of all de-augmented predictions, accumulated into a tensor
            shaped like ``input_img``.

        Inference runs under CUDA autocast (mixed precision) when CUDA is
        available.
        """
        augmentations = self._augmentation_types()
        # Multi-scale testing is opt-in; otherwise only the native resolution.
        scales = self.scales if self.use_multi_scale else [1.0]
        resize = torch.nn.functional.interpolate
        height, width = input_img.shape[2], input_img.shape[3]

        ensemble_result = torch.zeros_like(input_img)
        ensemble_weight = 0.0

        for scale in scales:
            scale_weight = 1.0  # every (scale, augmentation) pair counts equally
            rescaled = scale != 1.0
            if rescaled:
                # Clamp to 1 pixel so tiny inputs/scales cannot yield a 0-size resize.
                size = (max(1, int(height * scale)), max(1, int(width * scale)))
                img_s = resize(input_img, size=size, mode='bilinear', align_corners=False)
                point_s = resize(point, size=size, mode='bilinear', align_corners=False)
                normal_s = resize(normal, size=size, mode='bilinear', align_corners=False)
                # Interpolation does not preserve vector length: renormalise
                # (the epsilon keeps the sqrt finite for all-zero vectors).
                normal_s = normal_s / torch.sqrt(
                    torch.sum(normal_s ** 2, dim=1, keepdim=True) + 1e-6
                )
            else:
                img_s, point_s, normal_s = input_img, point, normal

            for aug_type in augmentations:
                img_aug, point_aug, normal_aug = self._apply_augmentation(
                    img_s, point_s, normal_s, aug_type
                )

                with torch.autocast(device_type='cuda', enabled=torch.cuda.is_available()):
                    result_aug = sliding_window(
                        model=self.model,
                        input_=img_aug,
                        point=point_aug,
                        normal=normal_aug,
                        dino_net=self.dino_net,
                        device=self.device,
                    )

                result_aug = self._reverse_augmentation(result_aug, aug_type)
                if rescaled:
                    result_aug = resize(
                        result_aug, size=(height, width), mode='bilinear', align_corners=False
                    )

                ensemble_result += result_aug * scale_weight
                ensemble_weight += scale_weight

        return ensemble_result / ensemble_weight