|
"""
|
|
Simple Anomaly Detector using Reconstruction Error
|
|
A minimal implementation for measuring corruption intensity, using the reconstruction error of a trained autoencoder as an anomaly score
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from PIL import Image
|
|
import torchvision.transforms as transforms
|
|
from typing import Union
|
|
import random
|
|
|
|
from models import Autoencoder
|
|
from utils.data_utils import ImageCorruption
|
|
import config
|
|
|
|
|
|
def apply_corruption(image_tensor: torch.Tensor, corruption_type: str = 'random') -> torch.Tensor:
    """
    Return a corrupted copy of an image tensor.

    The input is cloned before it is handed to the corruption routine, so the
    selected ``ImageCorruption`` method only ever receives a copy.

    Args:
        image_tensor: Image tensor to corrupt, documented by callers as (C, H, W).
        corruption_type: One of 'noise', 'blur', 'brightness', 'contrast' or
            'random' (the default).

    Returns:
        Whatever the selected ``ImageCorruption`` method returns for the clone.

    Raises:
        ValueError: If ``corruption_type`` is not one of the names listed above.
    """
    # corruption_prob=1.0 presumably makes the corruption always apply
    # (TODO confirm against ImageCorruption).
    corruptor = ImageCorruption(corruption_prob=1.0)

    # Public corruption name -> ImageCorruption method implementing it, checked in order.
    dispatch = (
        ('noise', 'gaussian_noise'),
        ('blur', 'blur'),
        ('brightness', 'brightness_change'),
        ('contrast', 'contrast_change'),
        ('random', 'apply_random_corruption'),
    )
    for public_name, method_name in dispatch:
        if corruption_type == public_name:
            corrupt = getattr(corruptor, method_name)
            return corrupt(image_tensor.clone())

    raise ValueError(f"Unknown corruption type: {corruption_type}")
|
|
|
|
|
|
class SimpleAnomalyDetector:
    """Score images by the reconstruction error of a trained autoencoder.

    The score is the mean squared error (``nn.MSELoss``) between the model's
    reconstruction and the preprocessed input; higher means more anomalous.
    """

    # ImageNet channel statistics used to normalize file / PIL inputs.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, model_path: str):
        """
        Initialize the detector with a trained autoencoder.

        Args:
            model_path: Path to the trained autoencoder checkpoint (.pth file).
        """
        self.device = torch.device(config.DEVICE)
        self.model = self._load_model(model_path)
        self.criterion = nn.MSELoss()

        target_size = self._target_size()
        # Public preprocessing pipeline, built from the same target size as
        # calculate_reconstruction_error so tuple-valued IMAGE_SIZE also works.
        # NOTE(review): transforms.Resize reads a tuple as (height, width), whereas
        # PIL's Image.resize (used for scoring) reads it as (width, height); the
        # two only agree for square sizes -- confirm config.IMAGE_SIZE.
        self.transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ])

        print(f"✅ Anomaly detector ready! Using device: {self.device}")
        print(f"📐 Image size: {target_size[0]}x{target_size[1]}")

    @staticmethod
    def _target_size() -> tuple:
        """Return config.IMAGE_SIZE as a 2-tuple, expanding a single int to (size, size)."""
        size = config.IMAGE_SIZE
        if isinstance(size, (tuple, list)):
            return tuple(size)
        return (size, size)

    def _load_model(self, model_path: str) -> Autoencoder:
        """Load the autoencoder weights from ``model_path`` onto ``self.device``.

        Accepts either a checkpoint dict that stores the weights under
        ``'model_state_dict'`` or a bare state dict. The model is returned in
        eval mode.
        """
        print(f"📥 Loading model from {model_path}")

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
        # objects, which can execute code. Only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        model = Autoencoder(
            input_channels=config.CHANNELS,
            latent_dim=config.LATENT_DIM,
        )

        # Training checkpoints wrap the weights; a bare state dict is used as-is.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)

        model.to(self.device)
        # Inference mode (affects layers such as dropout / batch-norm, if present).
        model.eval()
        return model

    def _preprocess_pil(self, pil_image: Image.Image, target_size: tuple) -> torch.Tensor:
        """Convert a PIL image into a normalized (1, C, H, W) tensor.

        The image is converted to RGB and resized with LANCZOS resampling to
        ``target_size`` (PIL interprets it as (width, height)). It is then turned
        into a tensor by ``ToTensor``, normalized with the ImageNet mean/std,
        and given a leading batch dimension.
        """
        resized = pil_image.convert('RGB').resize(target_size, Image.LANCZOS)
        tensor = transforms.ToTensor()(resized)
        normalize = transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD)
        return normalize(tensor).unsqueeze(0)

    def calculate_reconstruction_error(self, image: Union[str, Image.Image, torch.Tensor]) -> float:
        """
        Calculate the reconstruction error for an image.

        Args:
            image: Can be:
                - String path to an image file
                - PIL Image object
                - PyTorch tensor (C, H, W) or (1, C, H, W). Tensors are used as
                  given: no resize or normalization is applied.

        Returns:
            Reconstruction error as a float (higher = more anomalous). For a 4-D
            tensor with several images, MSELoss averages over the whole batch, so
            a single float is returned for the batch.

        Raises:
            ValueError: If the image cannot be loaded or processed, has
                unexpected tensor dimensions, or is of an unsupported type.
        """
        target_size = self._target_size()

        if isinstance(image, str):
            try:
                # The context manager closes the file once the pixels are loaded.
                with Image.open(image) as opened:
                    image_tensor = self._preprocess_pil(opened, target_size)
            except Exception as e:
                raise ValueError(f"Error loading image from {image}: {e}") from e

        elif isinstance(image, Image.Image):
            try:
                image_tensor = self._preprocess_pil(image, target_size)
            except Exception as e:
                raise ValueError(f"Error processing PIL Image: {e}") from e

        elif isinstance(image, torch.Tensor):
            # NOTE(review): tensors are assumed to be preprocessed the same way as
            # the file/PIL path (resized + ImageNet-normalized); confirm with callers.
            if image.dim() == 3:
                image_tensor = image.unsqueeze(0)
            elif image.dim() == 4:
                image_tensor = image
            else:
                raise ValueError(f"Unexpected tensor dimensions: {image.shape}")

        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

        image_tensor = image_tensor.to(self.device)

        with torch.no_grad():
            # The model returns a pair; only the first element (the
            # reconstruction) is scored.
            # NOTE(review): the target is the *normalized* input; if the decoder's
            # output range is bounded (e.g. sigmoid), the scales differ -- confirm.
            reconstructed, _ = self.model(image_tensor)
            error = self.criterion(reconstructed, image_tensor)

        return error.item()
|
|
|
|
|
|
def test_detector_example(
    model_path: str = "models/All_Datasets_MIX/best_autoencoder_All_Datasets_MIX.pth",
):
    """Demonstrate SimpleAnomalyDetector on a few images from the global test loader.

    The demo takes up to two images from each of the first three batches. For
    each one it prints the reconstruction error of the clean image and of its
    noise- and blur-corrupted versions, together with their ratio to the clean
    error. Errors are reported with print instead of being raised.

    Args:
        model_path: Path to the trained autoencoder checkpoint. Defaults to the
            All_Datasets_MIX checkpoint.
    """
    from itertools import islice

    num_batches = 3
    images_per_batch = 2

    def _format_ratio(error: float, baseline: float) -> str:
        # A perfect clean reconstruction (error 0) makes the ratio undefined.
        return f"x{error / baseline:.2f}" if baseline else "x n/a"

    try:
        detector = SimpleAnomalyDetector(model_path)

        from utils.data_utils import create_global_test_loader

        test_loader = create_global_test_loader(
            datasets=["Michel Daudon (w256 1k v1)", "Jonathan El-Beze (w256 1k v1)"],
            subversions=["MIX"],
        )

        print("\n🧪 Testing reconstruction errors:")
        print("=" * 50)

        # islice stops after num_batches without loading one extra batch.
        for i, (images, labels) in enumerate(islice(test_loader, num_batches)):
            for j in range(min(images_per_batch, images.size(0))):
                clean_image = images[j]

                clean_error = detector.calculate_reconstruction_error(clean_image)

                corrupted_noise = apply_corruption(clean_image, 'noise')
                corrupted_blur = apply_corruption(clean_image, 'blur')

                noise_error = detector.calculate_reconstruction_error(corrupted_noise)
                blur_error = detector.calculate_reconstruction_error(corrupted_blur)

                print(f"\nImage {i * images_per_batch + j + 1} (Class: {labels[j]}):")
                print(f" Clean: {clean_error:.6f}")
                print(f" Noise corrupted: {noise_error:.6f} ({_format_ratio(noise_error, clean_error)})")
                print(f" Blur corrupted: {blur_error:.6f} ({_format_ratio(blur_error, clean_error)})")

        print("\n💡 Usage tip: Higher reconstruction error = more anomalous/corrupted")
        print(" You can set a threshold (e.g., 0.01) above which images are considered anomalous")

    except FileNotFoundError:
        print(f"❌ Model file not found: {model_path}")
        print(" Please update the model_path variable with your actual model file")

    except Exception as e:
        print(f"❌ Error: {e}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not on import.
    test_detector_example()