File size: 969 Bytes
8b06175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torchvision.transforms import transforms
import torch


def process_image(image, shape=(500, 500)):
    """Turn an image into a normalized tensor with a leading batch axis.

    The image is resized to ``shape``, converted with ``ToTensor`` and then
    normalized per channel with the mean/std lists below. These values match
    the commonly used ImageNet statistics. Finally a batch dimension of size
    1 is prepended.

    Args:
        image: input accepted by ``transforms.Resize`` and
            ``transforms.ToTensor``. Presumably a PIL image — confirm with
            callers. The 3-element mean/std assume a 3-channel image.
        shape: target size forwarded to ``transforms.Resize``. A sequence is
            interpreted by torchvision as ``(height, width)``.

    Returns:
        torch.Tensor: the transformed image with an extra leading axis of
        size 1.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)
    normalized = pipeline(image)
    # Prepend a batch axis so the result can be fed straight to a model.
    batched = normalized.unsqueeze(0)
    return batched

def tensor_to_image(tensor):
    """Convert a normalized image tensor back into a PIL image.

    Reverses the per-channel mean/std normalization used by ``process_image``
    (same constants), clamps values into [0, 1] and converts the result with
    ``transforms.ToPILImage``.

    Args:
        tensor: image tensor of shape ``(C, H, W)``, or a batched tensor of
            shape ``(1, C, H, W)`` such as the output of ``process_image``.
            The 3-element constants assume ``C == 3``.

    Returns:
        The image produced by ``transforms.ToPILImage``.
    """
    # Work on a detached view so the de-normalization below is never recorded
    # by autograd (e.g. when the tensor is a parameter being optimized).
    tensor = tensor.detach()
    # ``ToPILImage`` only accepts 2-D/3-D tensors; drop a singleton batch axis
    # so the output of ``process_image`` round-trips without an error.
    if tensor.dim() == 4 and tensor.size(0) == 1:
        tensor = tensor.squeeze(0)
    # Normalize computes (x - mean) / std; with mean = -m/s and std = 1/s this
    # evaluates to x * s + m, i.e. the inverse of the forward normalization.
    inverse_normalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
    )
    tensor = inverse_normalize(tensor)
    # Keep values inside the [0, 1] float range expected by ToPILImage.
    tensor = torch.clamp(tensor, 0, 1)
    return transforms.ToPILImage()(tensor)