from torchvision.transforms import transforms
import torch

# Per-channel (R, G, B) statistics used to normalize images; these are the
# standard ImageNet values expected by torchvision's pretrained models.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def process_image(image, shape=(500, 500)):
    """Transform an image into a normalized, batched tensor.

    Args:
        image: A PIL image (or any input ``transforms.ToTensor`` accepts,
            e.g. a NumPy ``HxWxC`` array). PIL images in a non-RGB mode
            (``"L"``, ``"RGBA"``, ``"P"``, ...) are converted to RGB first,
            because the normalization below has exactly three channels of
            statistics.
        shape: Target size passed to ``transforms.Resize``; defaults to
            ``(500, 500)``.

    Returns:
        A float tensor with a leading batch axis of size 1, i.e.
        ``(1, C, H, W)`` (``C == 3`` for RGB input), normalized with
        ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    # Only PIL images expose ``mode``; other inputs are passed through as-is.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose(
        [
            transforms.Resize(shape),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(IMAGENET_MEAN), std=list(IMAGENET_STD)),
        ]
    )
    # Add the batch dimension most models expect.
    return transform(image).unsqueeze(0)


def tensor_to_image(tensor):
    """Transform a normalized tensor back into a PIL image.

    This is the inverse of ``process_image``.

    Args:
        tensor: A normalized float image tensor, either ``(C, H, W)`` or a
            single-image batch ``(1, C, H, W)`` such as the one returned by
            ``process_image``.

    Returns:
        A PIL image built from the de-normalized tensor, with values clamped
        to ``[0, 1]`` before conversion.
    """
    # Normalize(m', s') computes (x - m') / s'. Choosing m' = -m / s and
    # s' = 1 / s yields x * s + m, which undoes the forward normalization.
    inverse_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
        std=[1 / s for s in IMAGENET_STD],
    )
    # Work on a detached CPU copy: an output image needs no gradients, and
    # this avoids recording an autograd graph for the de-normalization.
    tensor = tensor.detach().cpu()
    # ToPILImage accepts only 2-D/3-D tensors, so drop the batch axis that
    # process_image adds.
    if tensor.dim() == 4 and tensor.size(0) == 1:
        tensor = tensor.squeeze(0)
    tensor = inverse_normalize(tensor)
    tensor = torch.clamp(tensor, 0, 1)
    return transforms.ToPILImage()(tensor)