|
from torchvision.transforms import transforms |
|
import torch |
|
|
|
|
|
def process_image(image, shape=(500, 500)):
    """Convert a PIL image into a normalized, batched tensor.

    The image is resized, converted to a float tensor, normalized
    per channel, and given a leading batch axis.

    Args:
        image: The input image. A PIL image is expected; images whose
            ``mode`` is not ``"RGB"`` (grayscale, RGBA, palette, ...) are
            converted to RGB first.
        shape: Target size forwarded unchanged to ``transforms.Resize``.
            Defaults to ``(500, 500)``.

    Returns:
        The transformed image tensor with an extra dimension of size 1
        inserted at position 0.
    """
    # Normalize below is hard-wired to three channels, so a 1- or 4-channel
    # image would fail with a channel-mismatch error. Only objects exposing a
    # string ``mode`` (PIL images) are converted; anything else is passed
    # through untouched so the pipeline reports its usual error.
    mode = getattr(image, "mode", None)
    if isinstance(mode, str) and mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose(
        [
            transforms.Resize(shape),
            transforms.ToTensor(),
            # Per-channel statistics; tensor_to_image inverts exactly these.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    # Add the batch axis.
    return transform(image).unsqueeze(0)
|
|
|
def tensor_to_image(tensor):
    """Convert a normalized image tensor back into a PIL image.

    Reverses a per-channel normalization with mean
    ``(0.485, 0.456, 0.406)`` and std ``(0.229, 0.224, 0.225)``, clamps the
    result to ``[0, 1]``, and converts it with ``transforms.ToPILImage``.

    Args:
        tensor: A 3-channel image tensor, either unbatched or with a leading
            batch axis of size 1 (which is removed before conversion).

    Returns:
        The image produced by ``transforms.ToPILImage``.

    Raises:
        ValueError: If ``tensor`` is 4-D with a batch size other than 1.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    # Work on a detached CPU copy so the conversion is not recorded by
    # autograd and does not depend on the tensor's original device.
    tensor = tensor.detach().cpu()

    # ToPILImage only accepts 2-D/3-D input; drop a singleton batch axis.
    if tensor.dim() == 4:
        if tensor.size(0) != 1:
            raise ValueError(
                "tensor_to_image expects a single image, got a batch of "
                f"{tensor.size(0)}"
            )
        tensor = tensor.squeeze(0)

    # Normalize computes (x - mean) / std, so using mean' = -mean / std and
    # std' = 1 / std yields x * std + mean, i.e. the inverse transform.
    inverse_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )
    tensor = inverse_normalize(tensor)

    # Keep values inside the range ToPILImage expects for float tensors.
    tensor = torch.clamp(tensor, 0, 1)
    return transforms.ToPILImage()(tensor)