# Fed-AE-Kidney-Stone-Corruption-Detection / simple_anomaly_detector.py
# Uploaded by Ivanrs - commit e75d4ed ("Upload 15 files", verified)
"""
Simple Anomaly Detector using Reconstruction Error
A minimal implementation for testing corruption intensity using autoencoder reconstruction error
"""
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from typing import Union
import random
from models import Autoencoder
from utils.data_utils import ImageCorruption
import config
def apply_corruption(image_tensor: torch.Tensor, corruption_type: str = 'random') -> torch.Tensor:
    """
    Return a corrupted copy of an image tensor.

    Args:
        image_tensor: Input image tensor, documented as (C, H, W). Only a clone
            is handed to the corruptor, so this tensor is not passed in directly.
        corruption_type: 'noise', 'blur', 'brightness', 'contrast', or 'random'
            (delegates the choice to ImageCorruption.apply_random_corruption).

    Returns:
        Whatever the selected ImageCorruption method returns for the clone.

    Raises:
        ValueError: If corruption_type is not one of the names listed above.
    """
    # corruption_prob=1.0 so the selected corruption is always applied
    corruptor = ImageCorruption(corruption_prob=1.0)

    # Ordered (name, ImageCorruption method name) pairs. The method is looked up
    # lazily so that only the selected one has to exist on the corruptor.
    dispatch = (
        ('noise', 'gaussian_noise'),
        ('blur', 'blur'),
        ('brightness', 'brightness_change'),
        ('contrast', 'contrast_change'),
        ('random', 'apply_random_corruption'),
    )
    for name, method_name in dispatch:
        if corruption_type == name:
            corrupt = getattr(corruptor, method_name)
            return corrupt(image_tensor.clone())

    raise ValueError(f"Unknown corruption type: {corruption_type}")
class SimpleAnomalyDetector:
    """Simple anomaly detector based on reconstruction error.

    Loads a trained ``Autoencoder`` and scores an image by the mean squared
    error between the (normalized) input and the model's reconstruction.
    A higher score means the image is reconstructed worse, i.e. it is more
    anomalous.
    """

    # ImageNet channel statistics used to normalize images loaded from a
    # path or PIL object.
    NORMALIZE_MEAN = (0.485, 0.456, 0.406)
    NORMALIZE_STD = (0.229, 0.224, 0.225)

    def __init__(self, model_path: str):
        """
        Initialize the detector with a trained autoencoder.

        Args:
            model_path: Path to the trained autoencoder (.pth file)
        """
        self.device = torch.device(config.DEVICE)
        self.model = self._load_model(model_path)
        self.criterion = nn.MSELoss()

        width, height = self._target_size()
        # Image preprocessing pipeline exposed on the instance.
        # calculate_reconstruction_error does not use it; it resizes with
        # LANCZOS via _pil_to_batch. torchvision's Resize expects
        # (height, width), while PIL's resize expects (width, height). Both are
        # derived from the same config value so the output sizes agree.
        self.transform = transforms.Compose([
            transforms.Resize((height, width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.NORMALIZE_MEAN, std=self.NORMALIZE_STD),
        ])

        print(f"βœ… Anomaly detector ready! Using device: {self.device}")
        print(f"πŸ“ Image size: {width}x{height}")

    @staticmethod
    def _target_size() -> tuple:
        """Return config.IMAGE_SIZE as a (width, height) pair for PIL's resize()."""
        size = config.IMAGE_SIZE
        if isinstance(size, (tuple, list)):
            return tuple(size)
        return (size, size)

    def _load_model(self, model_path: str) -> Autoencoder:
        """Build an Autoencoder from config and load trained weights into it.

        Accepts either a checkpoint dict that stores the weights under
        ``'model_state_dict'`` or a bare state dict. The model is moved to
        ``self.device`` and switched to eval mode.

        Raises:
            FileNotFoundError: Propagated from ``torch.load`` when the file is missing.
        """
        print(f"πŸ“₯ Loading model from {model_path}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which
        # can execute code. It is kept for compatibility with checkpoints that
        # store extra metadata. Only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Create model with the same architecture used in training
        model = Autoencoder(
            input_channels=config.CHANNELS,
            latent_dim=config.LATENT_DIM
        )

        # Support both {'model_state_dict': ...} checkpoints and raw state dicts
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def _pil_to_batch(self, pil_image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a normalized (1, C, H, W) tensor.

        The image is converted to RGB, resized with LANCZOS to the configured
        size, turned into a tensor and normalized with the ImageNet statistics.
        """
        rgb = pil_image.convert('RGB')
        rgb = rgb.resize(self._target_size(), Image.LANCZOS)
        tensor = transforms.ToTensor()(rgb)
        normalize = transforms.Normalize(mean=self.NORMALIZE_MEAN, std=self.NORMALIZE_STD)
        return normalize(tensor).unsqueeze(0)  # add batch dimension

    def calculate_reconstruction_error(self, image: Union[str, Image.Image, torch.Tensor]) -> float:
        """
        Calculate the reconstruction error for a single image.

        Args:
            image: Can be:
                - String path to an image file (loaded, resized and normalized)
                - PIL Image object (resized and normalized)
                - PyTorch tensor (C, H, W) or (N, C, H, W), used as-is. The
                  caller must apply the same preprocessing as the model expects.
                  With N > 1 the result is the mean error over the whole batch.

        Returns:
            Mean squared reconstruction error as a float (higher = more anomalous).

        Raises:
            ValueError: If the image cannot be loaded/processed, the tensor is
                not 3-D or 4-D, or the input type is unsupported.
        """
        if isinstance(image, str):
            try:
                # Context manager closes the underlying file handle
                with Image.open(image) as opened:
                    image_tensor = self._pil_to_batch(opened)
            except Exception as e:
                raise ValueError(f"Error loading image from {image}: {e}") from e
        elif isinstance(image, Image.Image):
            try:
                image_tensor = self._pil_to_batch(image)
            except Exception as e:
                raise ValueError(f"Error processing PIL Image: {e}") from e
        elif isinstance(image, torch.Tensor):
            if image.dim() == 3:  # (C, H, W)
                image_tensor = image.unsqueeze(0)  # add batch dimension
            elif image.dim() == 4:  # (N, C, H, W)
                image_tensor = image
            else:
                raise ValueError(f"Unexpected tensor dimensions: {image.shape}")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

        image_tensor = image_tensor.to(self.device)

        with torch.no_grad():
            # The model's forward is unpacked as (reconstruction, other); the
            # second output is not used here.
            reconstructed, _ = self.model(image_tensor)
            error = self.criterion(reconstructed, image_tensor)
        return error.item()
def _format_ratio(value: float, baseline: float) -> str:
    """Format value / baseline as 'x<ratio>' with 2 decimals, or 'n/a' when baseline is 0."""
    if baseline == 0:
        return "n/a"
    return f"x{value / baseline:.2f}"


def test_detector_example(model_path: str = "models/All_Datasets_MIX/best_autoencoder_All_Datasets_MIX.pth"):
    """Example usage of the simple anomaly detector.

    Loads the detector, then for the first 3 test batches compares the
    reconstruction error of up to 2 clean images per batch with noise- and
    blur-corrupted versions of the same images.

    Args:
        model_path: Path to the trained autoencoder checkpoint. The default is
            an example location; pass your own model file.
    """
    # Load the model separately so that a missing *model* file is not confused
    # with a FileNotFoundError raised later by the data loader.
    try:
        detector = SimpleAnomalyDetector(model_path)
    except FileNotFoundError:
        print(f"❌ Model file not found: {model_path}")
        print(" Please update the model_path variable with your actual model file")
        return
    except Exception as e:
        print(f"❌ Error: {e}")
        return

    try:
        # Test with some images from your dataset
        from utils.data_utils import create_global_test_loader

        test_loader = create_global_test_loader(
            datasets=["Michel Daudon (w256 1k v1)", "Jonathan El-Beze (w256 1k v1)"],
            subversions=["MIX"]
        )

        print("\nπŸ§ͺ Testing reconstruction errors:")
        print("=" * 50)

        # Running counter so numbering stays contiguous for any batch size
        image_number = 0
        for batch_idx, (images, labels) in enumerate(test_loader):
            if batch_idx >= 3:  # Test only first 3 batches
                break
            for j in range(min(2, images.size(0))):  # Test up to 2 images per batch
                image_number += 1
                clean_image = images[j]

                clean_error = detector.calculate_reconstruction_error(clean_image)

                corrupted_noise = apply_corruption(clean_image, 'noise')
                corrupted_blur = apply_corruption(clean_image, 'blur')
                noise_error = detector.calculate_reconstruction_error(corrupted_noise)
                blur_error = detector.calculate_reconstruction_error(corrupted_blur)

                label = labels[j]
                # Print scalar tensor labels as plain numbers, not "tensor(3)"
                if isinstance(label, torch.Tensor) and label.numel() == 1:
                    label = label.item()

                print(f"\nImage {image_number} (Class: {label}):")
                print(f" Clean: {clean_error:.6f}")
                print(f" Noise corrupted: {noise_error:.6f} ({_format_ratio(noise_error, clean_error)})")
                print(f" Blur corrupted: {blur_error:.6f} ({_format_ratio(blur_error, clean_error)})")

        print("\nπŸ’‘ Usage tip: Higher reconstruction error = more anomalous/corrupted")
        print(" You can set a threshold (e.g., 0.01) above which images are considered anomalous")
    except Exception as e:
        print(f"❌ Error: {e}")
if __name__ == "__main__":
    # Run the example only when this file is executed directly, not on import.
    test_detector_example()