DenseSR / utils /dataset_utils.py
VanLinLin's picture
birth
3f7c489
import torch
import os
import torchvision.transforms as transforms
class Augment_RGB_torch:
### rotate and flip
def __init__(self, rotate=0):
self.rotate = rotate
pass
def transform0(self, torch_tensor):
return torch_tensor
def transform1(self, torch_tensor):
H, W = torch_tensor.shape[1], torch_tensor.shape[2]
train_transform = transforms.Compose([
transforms.RandomRotation((self.rotate,self.rotate), interpolation=transforms.InterpolationMode.BILINEAR, expand=False),
transforms.Resize((int(H * 1.3), int(W * 1.3)), antialias=True),
# CenterCrop,if the size is larger than the original size, the excess will be filled with black
transforms.CenterCrop([H, W])
])
return train_transform(torch_tensor)
def transform2(self, torch_tensor):
torch_tensor = torch.rot90(torch_tensor, k=1, dims=[-1,-2])
return torch_tensor
def transform3(self, torch_tensor):
torch_tensor = torch.rot90(torch_tensor, k=2, dims=[-1,-2])
return torch_tensor
def transform4(self, torch_tensor):
torch_tensor = torch.rot90(torch_tensor, k=3, dims=[-1,-2])
return torch_tensor
def transform5(self, torch_tensor):
torch_tensor = torch_tensor.flip(-2)
return torch_tensor
def transform6(self, torch_tensor):
torch_tensor = (torch.rot90(torch_tensor, k=1, dims=[-1,-2])).flip(-2)
return torch_tensor
def transform7(self, torch_tensor):
torch_tensor = (torch.rot90(torch_tensor, k=2, dims=[-1,-2])).flip(-2)
return torch_tensor
def transform8(self, torch_tensor):
torch_tensor = (torch.rot90(torch_tensor, k=3, dims=[-1,-2])).flip(-2)
return torch_tensor