# DenseSR/utils/tta.py
import torch
class TestTimeAugmentation:
    """Test-Time Augmentation (TTA) for image restoration models.

    Runs the model on geometrically augmented copies of the input
    (horizontal/vertical flips, 90/180/270-degree rotations, and optionally
    multiple scales), undoes each augmentation on the corresponding output,
    and averages all de-augmented outputs into a single ensemble result.
    """

    def __init__(self, model, dino_net, device, use_flip=True, use_rot=True,
                 use_multi_scale=False, scales=None):
        """
        Args:
            model: The model to apply TTA to.
            dino_net: DINO feature extractor, forwarded to the sliding-window runner.
            device: Device to run inference on, forwarded to the sliding-window runner.
            use_flip: Whether to use horizontal and vertical flips.
            use_rot: Whether to use 90-degree rotations.
            use_multi_scale: Whether to use multi-scale testing.
                NOTE(review): this flag is currently never read -- multi-scale
                behavior is driven solely by `scales`. Kept for interface
                compatibility; confirm intended semantics.
            scales: List of scales for multi-scale testing, e.g. [0.8, 1.0, 1.2].
                Defaults to [1.0] (single scale) when None or empty.
        """
        self.model = model
        self.dino_net = dino_net
        self.device = device
        self.use_flip = use_flip
        self.use_rot = use_rot
        self.use_multi_scale = use_multi_scale
        self.scales = scales or [1.0]

    def _apply_augmentation(self, image, point, normal, aug_type):
        """Apply a single geometric augmentation to the input maps.

        Flips/rotations are applied spatially to all three maps; the normal
        map additionally has its (x, y) vector components transformed so the
        surface normals stay consistent with the new spatial orientation
        (channel 0 = x, channel 1 = y by this module's convention).

        Args:
            image: Input RGB image [B, C, H, W].
            point: Point map [B, C, H, W].
            normal: Normal map [B, C, H, W] (channels 0/1 are x/y components).
            aug_type: One of 'original', 'h_flip', 'v_flip', 'rot90',
                'rot180', 'rot270'.

        Returns:
            Tuple (image_aug, point_aug, normal_aug). For 'original' the
            inputs are returned unchanged (no copy); otherwise new tensors
            are returned (torch.flip/rot90 produce copies, so the in-place
            channel fix-ups below do not mutate the caller's tensors).

        Raises:
            ValueError: If `aug_type` is not a recognized augmentation.
        """
        if aug_type == 'original':
            return image, point, normal
        elif aug_type == 'h_flip':
            # Horizontal flip (width axis).
            img_aug = torch.flip(image, dims=[3])
            point_aug = torch.flip(point, dims=[3])
            normal_aug = torch.flip(normal, dims=[3])
            # Mirroring left-right negates the x component of the normals.
            normal_aug[:, 0, :, :] = -normal_aug[:, 0, :, :]
            return img_aug, point_aug, normal_aug
        elif aug_type == 'v_flip':
            # Vertical flip (height axis).
            img_aug = torch.flip(image, dims=[2])
            point_aug = torch.flip(point, dims=[2])
            normal_aug = torch.flip(normal, dims=[2])
            # Mirroring top-bottom negates the y component of the normals.
            normal_aug[:, 1, :, :] = -normal_aug[:, 1, :, :]
            return img_aug, point_aug, normal_aug
        elif aug_type == 'rot90':
            # 90-degree rotation in the (H, W) plane.
            img_aug = torch.rot90(image, k=1, dims=[2, 3])
            point_aug = torch.rot90(point, k=1, dims=[2, 3])
            normal_aug = torch.rot90(normal, k=1, dims=[2, 3])
            # Rotate the normal vectors with the image: new_x = -old_y,
            # new_y = old_x (90-degree vector rotation).
            normal_x = -normal_aug[:, 1, :, :].clone()
            normal_y = normal_aug[:, 0, :, :].clone()
            normal_aug[:, 0, :, :] = normal_x
            normal_aug[:, 1, :, :] = normal_y
            return img_aug, point_aug, normal_aug
        elif aug_type == 'rot180':
            # 180-degree rotation.
            img_aug = torch.rot90(image, k=2, dims=[2, 3])
            point_aug = torch.rot90(point, k=2, dims=[2, 3])
            normal_aug = torch.rot90(normal, k=2, dims=[2, 3])
            # 180-degree rotation negates both x and y normal components.
            normal_aug[:, 0, :, :] = -normal_aug[:, 0, :, :]
            normal_aug[:, 1, :, :] = -normal_aug[:, 1, :, :]
            return img_aug, point_aug, normal_aug
        elif aug_type == 'rot270':
            # 270-degree rotation.
            img_aug = torch.rot90(image, k=3, dims=[2, 3])
            point_aug = torch.rot90(point, k=3, dims=[2, 3])
            normal_aug = torch.rot90(normal, k=3, dims=[2, 3])
            # Rotate the normal vectors with the image: new_x = old_y,
            # new_y = -old_x (270-degree vector rotation).
            normal_x = normal_aug[:, 1, :, :].clone()
            normal_y = -normal_aug[:, 0, :, :].clone()
            normal_aug[:, 0, :, :] = normal_x
            normal_aug[:, 1, :, :] = normal_y
            return img_aug, point_aug, normal_aug
        else:
            raise ValueError(f"Unknown augmentation type: {aug_type}")

    def _reverse_augmentation(self, result, aug_type):
        """Undo a geometric augmentation on a model output.

        Only the spatial transform is inverted (flip back, rotate by the
        complementary angle); no channel fix-up is needed because the model
        output is an RGB restoration result, not a normal map.

        Args:
            result: Model output [B, C, H, W] produced from augmented input.
            aug_type: The augmentation that was applied to the input.

        Returns:
            The de-augmented result, aligned with the original orientation.

        Raises:
            ValueError: If `aug_type` is not a recognized augmentation.
        """
        if aug_type == 'original':
            return result
        elif aug_type == 'h_flip':
            return torch.flip(result, dims=[3])
        elif aug_type == 'v_flip':
            return torch.flip(result, dims=[2])
        elif aug_type == 'rot90':
            # Inverse of k=1 is k=3.
            return torch.rot90(result, k=3, dims=[2, 3])
        elif aug_type == 'rot180':
            # 180 degrees is its own inverse.
            return torch.rot90(result, k=2, dims=[2, 3])
        elif aug_type == 'rot270':
            # Inverse of k=3 is k=1.
            return torch.rot90(result, k=1, dims=[2, 3])
        else:
            raise ValueError(f"Unknown augmentation type: {aug_type}")

    def __call__(self, sliding_window, input_img, point, normal):
        """Apply TTA and return the ensemble result.

        Args:
            sliding_window: SlidingWindowInference-style callable; invoked as
                sliding_window(model=..., input_=..., point=..., normal=...,
                dino_net=..., device=...) and expected to return a [B, C, H', W']
                tensor (H'/W' may differ from the input for super-resolution).
            input_img: Input RGB image [B, C, H, W].
            point: Point map [B, C, H, W].
            normal: Normal map [B, C, H, W].

        Returns:
            Ensemble (weighted mean) of all de-augmented outputs
            [B, C, H', W'].
        """
        # Build the list of augmentations to ensemble over.
        augmentations = ['original']
        if self.use_flip:
            augmentations.extend(['h_flip', 'v_flip'])
        if self.use_rot:
            augmentations.extend(['rot90', 'rot180', 'rot270'])

        resize_fn = torch.nn.functional.interpolate
        h, w = input_img.shape[2], input_img.shape[3]

        # The accumulator is initialized lazily from the first model output
        # instead of zeros_like(input_img): the model's output resolution may
        # differ from the input (super-resolution), and the old eager init
        # crashed in that case.
        ensemble_result = None
        target_size = None
        ensemble_weight = 0.0

        for scale in self.scales:
            scale_weight = 1.0
            if scale != 1.0:
                # Resize all inputs for this multi-scale pass.
                new_h, new_w = int(h * scale), int(w * scale)
                input_img_scaled = resize_fn(input_img, size=(new_h, new_w), mode='bilinear', align_corners=False)
                point_scaled = resize_fn(point, size=(new_h, new_w), mode='bilinear', align_corners=False)
                normal_scaled = resize_fn(normal, size=(new_h, new_w), mode='bilinear', align_corners=False)
                # Re-normalize normal vectors after interpolation (bilinear
                # blending shortens them); epsilon avoids division by zero.
                normal_norm = torch.sqrt(torch.sum(normal_scaled**2, dim=1, keepdim=True) + 1e-6)
                normal_scaled = normal_scaled / normal_norm
            else:
                input_img_scaled = input_img
                point_scaled = point
                normal_scaled = normal

            for aug_type in augmentations:
                img_aug, point_aug, normal_aug = self._apply_augmentation(
                    input_img_scaled, point_scaled, normal_scaled, aug_type
                )
                # NOTE(review): autocast is hard-coded to CUDA even though a
                # `device` argument exists (it is silently disabled on
                # CPU-only machines), and torch.cuda.amp.autocast is
                # deprecated in favour of torch.amp.autocast('cuda'). Left
                # unchanged to avoid requiring a newer torch / altering
                # numeric behavior.
                with torch.cuda.amp.autocast():
                    result_aug = sliding_window(
                        model=self.model,
                        input_=img_aug,
                        point=point_aug,
                        normal=normal_aug,
                        dino_net=self.dino_net,
                        device=self.device
                    )
                # Undo the geometric augmentation on the output.
                result_aug = self._reverse_augmentation(result_aug, aug_type)

                if target_size is None:
                    # The first result fixes the ensemble resolution. For a
                    # scaled pass, undo the input scaling so every pass agrees;
                    # when 1.0 is among the scales (the usual case) and runs
                    # first, this is exactly the model's native output size.
                    if scale != 1.0:
                        target_size = (
                            int(round(result_aug.shape[2] / scale)),
                            int(round(result_aug.shape[3] / scale)),
                        )
                    else:
                        target_size = (result_aug.shape[2], result_aug.shape[3])
                if (result_aug.shape[2], result_aug.shape[3]) != target_size:
                    # Resize back to the ensemble resolution (multi-scale pass).
                    result_aug = resize_fn(result_aug, size=target_size, mode='bilinear', align_corners=False)

                # Accumulate the weighted output.
                if ensemble_result is None:
                    ensemble_result = result_aug * scale_weight
                else:
                    ensemble_result = ensemble_result + result_aug * scale_weight
                ensemble_weight += scale_weight

        # Weighted average over all scales and augmentations.
        return ensemble_result / ensemble_weight