Upload 15 files
- .gitattributes +5 -0
- example_imgs/72222-SectionIVa+WK maj_0009-60.png +0 -0
- example_imgs/Section_Va_72845-3-18.png +3 -0
- example_imgs/TypeIVa2_N47583_Notteb-11.png +3 -0
- example_imgs/TypeIVd_Sect_LC3373-65.png +3 -0
- example_imgs/TypeIa_LaosN°15_Image21-25.png +3 -0
- example_imgs/typIVc_IVbsectbis-43.png +3 -0
- models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth +3 -0
- models/Daudon_SEC/best_autoencoder_Daudon_SEC.pth +3 -0
- models/Daudon_SUR/best_autoencoder_Daudon_SUR.pth +3 -0
- requirements.txt +8 -0
- simple_anomaly_detector.py +214 -0
- simple_gradio_app.py +357 -0
- utils/__init__.py +13 -0
- utils/data_utils.py +416 -0
- utils/metrics.py +316 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+example_imgs/Section_Va_72845-3-18.png filter=lfs diff=lfs merge=lfs -text
+example_imgs/TypeIa_LaosN°15_Image21-25.png filter=lfs diff=lfs merge=lfs -text
+example_imgs/TypeIVa2_N47583_Notteb-11.png filter=lfs diff=lfs merge=lfs -text
+example_imgs/TypeIVd_Sect_LC3373-65.png filter=lfs diff=lfs merge=lfs -text
+example_imgs/typIVc_IVbsectbis-43.png filter=lfs diff=lfs merge=lfs -text
example_imgs/72222-SectionIVa+WK maj_0009-60.png
ADDED
example_imgs/Section_Va_72845-3-18.png
ADDED
example_imgs/TypeIVa2_N47583_Notteb-11.png
ADDED
example_imgs/TypeIVd_Sect_LC3373-65.png
ADDED
example_imgs/TypeIa_LaosN°15_Image21-25.png
ADDED
example_imgs/typIVc_IVbsectbis-43.png
ADDED
models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b63d58b9d42c9de2b912d2cbd770c1b1b93c591965e44e6d1479ee0caa01007
size 123579888
models/Daudon_SEC/best_autoencoder_Daudon_SEC.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:687f67252659ffbdde997adb689b24d80b8f77a4baedca2d52cdadbd4d30fa65
size 123579888
models/Daudon_SUR/best_autoencoder_Daudon_SUR.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e969a304b220177169ff5e46668110cb62fdd4752ffdf4ecb84bbf50d68bff61
size 123579888
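The three checkpoints above are Git LFS pointers to ~124 MB PyTorch files. Below is a minimal inspection sketch, assuming each checkpoint was saved the way simple_anomaly_detector.py reads it (a dict holding a 'model_state_dict' entry, loaded with weights_only=False); the repository's own Autoencoder class is needed to actually instantiate the model.

# Sketch: inspecting one uploaded checkpoint (assumed layout, per SimpleAnomalyDetector._load_model)
import torch

checkpoint = torch.load(
    "models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth",
    map_location="cpu",
    weights_only=False,  # checkpoints are assumed to carry metadata besides raw weights
)
print(list(checkpoint.keys()))                      # expect at least 'model_state_dict'
state_dict = checkpoint["model_state_dict"]
print(sum(p.numel() for p in state_dict.values()), "parameters")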
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch>=2.0.0
torchvision>=0.15.0
flwr>=1.6.0
numpy>=1.24.0
Pillow>=9.5.0
matplotlib>=3.7.0
scikit-learn>=1.3.0
tqdm>=4.65.0
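The pinned packages above install with pip install -r requirements.txt. Note that simple_gradio_app.py additionally imports gradio, scipy and seaborn, and both scripts expect local models and config modules that are not part of this upload, so a working environment appears to need those on top of this list.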
simple_anomaly_detector.py
ADDED
@@ -0,0 +1,214 @@
"""
Simple Anomaly Detector using Reconstruction Error
A minimal implementation for testing corruption intensity using autoencoder reconstruction error
"""

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from typing import Union
import random

from models import Autoencoder
from utils.data_utils import ImageCorruption
import config


def apply_corruption(image_tensor: torch.Tensor, corruption_type: str = 'random') -> torch.Tensor:
    """
    Simple function to apply corruption to an image tensor

    Args:
        image_tensor: Input image tensor (C, H, W)
        corruption_type: Type of corruption ('noise', 'blur', 'brightness', 'contrast', 'random')

    Returns:
        Corrupted image tensor
    """
    # Create corruption object with 100% probability to ensure corruption is applied
    corruptor = ImageCorruption(corruption_prob=1.0)

    if corruption_type == 'noise':
        return corruptor.gaussian_noise(image_tensor.clone())
    elif corruption_type == 'blur':
        return corruptor.blur(image_tensor.clone())
    elif corruption_type == 'brightness':
        return corruptor.brightness_change(image_tensor.clone())
    elif corruption_type == 'contrast':
        return corruptor.contrast_change(image_tensor.clone())
    elif corruption_type == 'random':
        return corruptor.apply_random_corruption(image_tensor.clone())
    else:
        raise ValueError(f"Unknown corruption type: {corruption_type}")


class SimpleAnomalyDetector:
    """Simple anomaly detector based on reconstruction error"""

    def __init__(self, model_path: str):
        """
        Initialize the detector with a trained autoencoder

        Args:
            model_path: Path to the trained autoencoder (.pth file)
        """
        self.device = torch.device(config.DEVICE)
        self.model = self._load_model(model_path)
        self.criterion = nn.MSELoss()

        # Image preprocessing - simplified and more robust
        self.transform = transforms.Compose([
            transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"✅ Anomaly detector ready! Using device: {self.device}")
        print(f"📏 Image size: {config.IMAGE_SIZE}x{config.IMAGE_SIZE}")

    def _load_model(self, model_path: str) -> Autoencoder:
        """Load the trained autoencoder model"""
        print(f"📥 Loading model from {model_path}")

        # Load checkpoint (weights_only=False for compatibility with saved metadata)
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Create model with same architecture
        model = Autoencoder(
            input_channels=config.CHANNELS,
            latent_dim=config.LATENT_DIM
        )

        # Load trained weights
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(self.device)
        model.eval()

        return model

    def calculate_reconstruction_error(self, image: Union[str, Image.Image, torch.Tensor]) -> float:
        """
        Calculate reconstruction error for a single image

        Args:
            image: Can be:
                - String path to image file
                - PIL Image object
                - PyTorch tensor (C, H, W) or (1, C, H, W)

        Returns:
            Reconstruction error as a float (higher = more anomalous)
        """
        # Get image size - handle both tuple and integer formats
        if isinstance(config.IMAGE_SIZE, tuple):
            target_size = config.IMAGE_SIZE  # (256, 256)
        else:
            target_size = (config.IMAGE_SIZE, config.IMAGE_SIZE)

        # Convert input to tensor
        if isinstance(image, str):
            # Load from file path
            try:
                image_pil = Image.open(image).convert('RGB')
                # Resize the image properly
                image_pil = image_pil.resize(target_size, Image.LANCZOS)
                image_tensor = transforms.ToTensor()(image_pil)
                # Apply normalization
                normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                image_tensor = normalize(image_tensor).unsqueeze(0)  # Add batch dimension
            except Exception as e:
                raise ValueError(f"Error loading image from {image}: {e}")

        elif isinstance(image, Image.Image):
            # PIL Image
            try:
                image_pil = image.convert('RGB')
                image_pil = image_pil.resize(target_size, Image.LANCZOS)
                image_tensor = transforms.ToTensor()(image_pil)
                normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                image_tensor = normalize(image_tensor).unsqueeze(0)
            except Exception as e:
                raise ValueError(f"Error processing PIL Image: {e}")

        elif isinstance(image, torch.Tensor):
            # PyTorch tensor
            if image.dim() == 3:  # (C, H, W)
                image_tensor = image.unsqueeze(0)  # Add batch dimension
            elif image.dim() == 4:  # (1, C, H, W)
                image_tensor = image
            else:
                raise ValueError(f"Unexpected tensor dimensions: {image.shape}")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

        # Move to device
        image_tensor = image_tensor.to(self.device)

        # Calculate reconstruction error
        with torch.no_grad():
            reconstructed, _ = self.model(image_tensor)
            error = self.criterion(reconstructed, image_tensor)

        return error.item()


def test_detector_example():
    """Example usage of the simple anomaly detector"""

    # You need to specify the path to your trained model
    model_path = "models/All_Datasets_MIX/best_autoencoder_All_Datasets_MIX.pth"  # Change this!

    try:
        # Initialize detector
        detector = SimpleAnomalyDetector(model_path)

        # Test with some images from your dataset
        from utils.data_utils import create_global_test_loader

        # Get a test loader
        test_loader = create_global_test_loader(
            datasets=["Michel Daudon (w256 1k v1)", "Jonathan El-Beze (w256 1k v1)"],
            subversions=["MIX"]
        )

        print("\n🧪 Testing reconstruction errors:")
        print("=" * 50)

        # Test a few images
        for i, (images, labels) in enumerate(test_loader):
            if i >= 3:  # Test only first 3 batches
                break

            for j in range(min(2, images.size(0))):  # Test 2 images per batch
                clean_image = images[j]

                # Test clean image
                clean_error = detector.calculate_reconstruction_error(clean_image)

                # Test corrupted versions
                corrupted_noise = apply_corruption(clean_image, 'noise')
                corrupted_blur = apply_corruption(clean_image, 'blur')

                noise_error = detector.calculate_reconstruction_error(corrupted_noise)
                blur_error = detector.calculate_reconstruction_error(corrupted_blur)

                print(f"\nImage {i*2 + j + 1} (Class: {labels[j]}):")
                print(f"  Clean: {clean_error:.6f}")
                print(f"  Noise corrupted: {noise_error:.6f} (x{noise_error/clean_error:.2f})")
                print(f"  Blur corrupted: {blur_error:.6f} (x{blur_error/clean_error:.2f})")

        print(f"\n💡 Usage tip: Higher reconstruction error = more anomalous/corrupted")
        print(f"   You can set a threshold (e.g., 0.01) above which images are considered anomalous")

    except FileNotFoundError:
        print(f"❌ Model file not found: {model_path}")
        print("   Please update the model_path variable with your actual model file")
    except Exception as e:
        print(f"❌ Error: {e}")


if __name__ == "__main__":
    test_detector_example()
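A minimal usage sketch of the class above against the checkpoints and example images in this upload. It assumes config.py (not in this commit) defines DEVICE, IMAGE_SIZE, CHANNELS and LATENT_DIM matching the training setup, and that the example images already match IMAGE_SIZE, since tensor inputs are not resized.

# Hypothetical usage sketch; paths exist in this upload, config values are assumed
from simple_anomaly_detector import SimpleAnomalyDetector, apply_corruption
from PIL import Image
import torchvision.transforms as transforms

detector = SimpleAnomalyDetector("models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth")

# Score a clean example image by file path
clean_error = detector.calculate_reconstruction_error("example_imgs/Section_Va_72845-3-18.png")

# Score a corrupted tensor version of the same image
image = Image.open("example_imgs/Section_Va_72845-3-18.png").convert("RGB")
tensor = transforms.ToTensor()(image)
noisy_error = detector.calculate_reconstruction_error(apply_corruption(tensor, "noise"))

print(f"clean: {clean_error:.6f}  noisy: {noisy_error:.6f}")  # higher = more anomalous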
simple_gradio_app.py
ADDED
@@ -0,0 +1,357 @@
"""
Simple Gradio Application for Anomaly Detection Testing
Shows embedding analysis instead of reconstructed images
"""

import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import io
import base64
from simple_anomaly_detector import SimpleAnomalyDetector
from image_corruption_utils import corrupt_image


# Global variables to store models
models = {
    "Daudon_MIX": "models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth",
    "Daudon_SEC": "models/Daudon_SEC/best_autoencoder_Daudon_SEC.pth",
    "Daudon_SUR": "models/Daudon_SUR/best_autoencoder_Daudon_SUR.pth"
}

current_detector = None
current_model_name = None


def load_model(model_name):
    """Load the selected model"""
    global current_detector, current_model_name

    try:
        if model_name != current_model_name:
            print(f"Loading model: {model_name}")
            model_path = models[model_name]
            current_detector = SimpleAnomalyDetector(model_path)
            current_model_name = model_name
            return f"✅ Model {model_name} loaded!"
        return f"✅ Model {model_name} already loaded"
    except Exception as e:
        return f"❌ Error loading {model_name}: {str(e)}"


def get_embedding_and_stats(image):
    """Get embedding from autoencoder and calculate statistics"""
    try:
        from torchvision import transforms
        import config

        # Get image size
        if isinstance(config.IMAGE_SIZE, tuple):
            target_size = config.IMAGE_SIZE
        else:
            target_size = (config.IMAGE_SIZE, config.IMAGE_SIZE)

        # Preprocess
        image_pil = image.convert('RGB').resize(target_size, Image.LANCZOS)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image_tensor = transform(image_pil).unsqueeze(0).to(current_detector.device)

        # Get embedding (latent representation)
        with torch.no_grad():
            _, embedding = current_detector.model(image_tensor)

        # Convert to numpy for analysis
        embedding_np = embedding.squeeze(0).cpu().numpy().flatten()

        # Calculate statistics
        stats_dict = {
            'mean': float(np.mean(embedding_np)),
            'median': float(np.median(embedding_np)),
            'std': float(np.std(embedding_np)),
            'min': float(np.min(embedding_np)),
            'max': float(np.max(embedding_np)),
            'q25': float(np.percentile(embedding_np, 25)),
            'q75': float(np.percentile(embedding_np, 75)),
            'skewness': float(stats.skew(embedding_np)),
            'kurtosis': float(stats.kurtosis(embedding_np)),
            'variance': float(np.var(embedding_np)),
            'range': float(np.max(embedding_np) - np.min(embedding_np)),
            'iqr': float(np.percentile(embedding_np, 75) - np.percentile(embedding_np, 25))
        }

        # Create visualization
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle(f'Embedding Analysis (Dimension: {len(embedding_np)})', fontsize=16)

        # Histogram
        axes[0, 0].hist(embedding_np, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
        axes[0, 0].set_title('Distribution Histogram')
        axes[0, 0].set_xlabel('Embedding Values')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].grid(True, alpha=0.3)

        # Box plot
        axes[0, 1].boxplot(embedding_np, vert=True)
        axes[0, 1].set_title('Box Plot')
        axes[0, 1].set_ylabel('Embedding Values')
        axes[0, 1].grid(True, alpha=0.3)

        # Q-Q plot (normal distribution)
        stats.probplot(embedding_np, dist="norm", plot=axes[1, 0])
        axes[1, 0].set_title('Q-Q Plot (Normal Distribution)')
        axes[1, 0].grid(True, alpha=0.3)

        # Embedding values plot
        axes[1, 1].plot(embedding_np, alpha=0.7, color='red', linewidth=1)
        axes[1, 1].set_title('Embedding Values Sequence')
        axes[1, 1].set_xlabel('Dimension Index')
        axes[1, 1].set_ylabel('Value')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()

        # Convert plot to image
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        plot_image = Image.open(buf)
        plt.close()

        return embedding_np, stats_dict, plot_image

    except Exception as e:
        print(f"Error in embedding analysis: {e}")
        return None, {}, None


def format_stats_text(stats_dict):
    """Format statistics into readable text"""
    if not stats_dict:
        return "❌ Error calculating statistics"

    text = f"""📊 EMBEDDING STATISTICS

🎯 Central Tendency:
Mean: {stats_dict['mean']:.6f}
Median: {stats_dict['median']:.6f}

📏 Spread:
Std Dev: {stats_dict['std']:.6f}
Variance: {stats_dict['variance']:.6f}
Range: {stats_dict['range']:.6f}
IQR: {stats_dict['iqr']:.6f}

📈 Extremes:
Min: {stats_dict['min']:.6f}
Max: {stats_dict['max']:.6f}
Q25: {stats_dict['q25']:.6f}
Q75: {stats_dict['q75']:.6f}

🔄 Shape:
Skewness: {stats_dict['skewness']:.6f}
Kurtosis: {stats_dict['kurtosis']:.6f}
"""

    return text


def classify_image(reconstruction_error, threshold):
    """Classify image as corrupted or clean based on threshold"""
    is_corrupted = reconstruction_error > threshold
    confidence = abs(reconstruction_error - threshold) / threshold * 100

    if is_corrupted:
        classification = "🚨 CORRUPTED/ANOMALOUS"
        color_indicator = "🔴"
        explanation = f"Reconstruction error ({reconstruction_error:.6f}) > Threshold ({threshold:.6f})"
    else:
        classification = "✅ CLEAN/NORMAL"
        color_indicator = "🟢"
        explanation = f"Reconstruction error ({reconstruction_error:.6f}) ≤ Threshold ({threshold:.6f})"

    # Calculate how far from threshold (as percentage)
    distance_pct = (reconstruction_error - threshold) / threshold * 100

    classification_text = f"""🎯 ANOMALY CLASSIFICATION

{color_indicator} Status: {classification}

📊 Details:
Reconstruction Error: {reconstruction_error:.6f}
Threshold: {threshold:.6f}
Distance from Threshold: {distance_pct:+.2f}%

📝 Explanation:
{explanation}

💡 Confidence Indicator:
• Distance > 50%: High confidence
• Distance 10-50%: Medium confidence
• Distance < 10%: Low confidence (near threshold)

🎚️ Current Distance: {abs(distance_pct):.2f}% ({'High' if abs(distance_pct) > 50 else 'Medium' if abs(distance_pct) > 10 else 'Low'} confidence)"""

    return classification_text, is_corrupted


def process_image(model_name, image, corruption_type, intensity, threshold):
    """Main processing function"""
    try:
        # Load model
        load_status = load_model(model_name)
        if "❌" in load_status:
            return None, None, load_status, 0.0, "", ""

        if image is None:
            return None, None, "❌ Please upload an image", 0.0, "", ""

        # Apply corruption
        if corruption_type == "none":
            corrupted_image = image.copy()
            corruption_info = "No corruption applied"
        else:
            corrupted_image = corrupt_image(image, corruption_type, intensity)
            corruption_info = f"Applied {corruption_type} corruption (intensity: {intensity})"

        # Calculate reconstruction error
        error = current_detector.calculate_reconstruction_error(corrupted_image)

        # Get embedding and statistics
        embedding, stats_dict, plot_image = get_embedding_and_stats(corrupted_image)

        # Format statistics text
        stats_text = format_stats_text(stats_dict)

        # Classify image based on threshold
        classification_text, is_corrupted = classify_image(error, threshold)

        # Status message
        status = f"""✅ Processing complete!
📊 Model: {model_name}
🔧 {corruption_info}
📈 Reconstruction Error: {error:.6f}
🎚️ Threshold: {threshold:.6f}
🎯 Classification: {'CORRUPTED' if is_corrupted else 'CLEAN'}
🧠 Embedding Dimension: {len(embedding) if embedding is not None else 'N/A'}
💡 Higher error = more anomalous"""

        return corrupted_image, plot_image, status, error, stats_text, classification_text

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return None, None, error_msg, 0.0, "", ""


# Create interface
def create_interface():
    with gr.Blocks(title="Anomaly Detection Tester") as demo:
        gr.Markdown("# 🔍 Federated Autoencoder for Kidney Stone Image Corruption Detection")
        gr.Markdown("Upload an image, analyze its latent representation, and classify it as corrupted or clean using a threshold.")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Model & Corruption Settings")

                model_dropdown = gr.Dropdown(
                    choices=list(models.keys()),
                    value="Daudon_MIX",
                    label="🤖 Select Model"
                )

                corruption_dropdown = gr.Dropdown(
                    choices=["none", "noise", "blur", "brightness", "contrast", "saturation", "random"],
                    value="none",
                    label="🔧 Corruption Type"
                )

                intensity_slider = gr.Slider(
                    minimum=0.1,
                    maximum=3.0,
                    value=1.0,
                    step=0.1,
                    label="💪 Corruption Intensity"
                )

                gr.Markdown("### 🎚️ Classification Settings")

                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=3.0,
                    value=1.0,
                    step=0.1,
                    label="🎯 Anomaly Threshold (Reconstruction Error)"
                )

                gr.Markdown("### 📸 Image Input")

                image_input = gr.Image(type="pil", label="Upload Image")

                # Add examples section
                gr.Markdown("### 📁 Example Images")

                # You can specify your example image paths here
                example_images = [
                    ["example_imgs/TypeIa_LaosN°15_Image21-25.png", "Clean Daudon MIX-Subtype_Ia"],
                    ["example_imgs/72222-SectionIVa+WK maj_0009-60.png", "Clean Daudon MIX-Subtype_IVa"],
                    ["example_imgs/TypeIVa2_N47583_Notteb-11.png", "Clean Daudon MIX-Subtype_IVa2"],
                    ["example_imgs/typIVc_IVbsectbis-43.png", "Clean Daudon MIX-Subtype_IVc"],
                    ["example_imgs/TypeIVd_Sect_LC3373-65.png", "Clean Daudon MIX-Subtype_IVd"],
                    ["example_imgs/Section_Va_72845-3-18.png", "Clean Daudon MIX-Subtype_Va"],
                ]

                examples_component = gr.Examples(
                    examples=example_images,
                    inputs=image_input,
                    label="Daudon MIX Example Clean Images",
                    examples_per_page=6,
                    cache_examples=False
                )

                process_btn = gr.Button("🚀 Analyze & Classify", variant="primary", size="lg")

            with gr.Column(scale=1):
                gr.Markdown("### 📊 Results")

                status_output = gr.Textbox(label="📋 Status", lines=8)
                error_output = gr.Number(label="📈 Reconstruction Error", precision=6)

                corrupted_output = gr.Image(label="🔧 Input Image (Corrupted)")

        with gr.Row():
            embedding_plot = gr.Image(label="🧠 Embedding Analysis")

        with gr.Row():
            stats_output = gr.Textbox(label="📊 Embedding Statistics", lines=20)
            classification_output = gr.Textbox(label="🎯 Classification Result", lines=15)

        # Connect the button
        process_btn.click(
            fn=process_image,
            inputs=[model_dropdown, image_input, corruption_dropdown, intensity_slider, threshold_slider],
            outputs=[corrupted_output, embedding_plot, status_output, error_output, stats_output, classification_output]
        )

    return demo


if __name__ == "__main__":
    print("🚀 Starting Embedding Analysis App...")

    demo = create_interface()
    demo.launch(
        server_name="127.0.0.1",
        server_port=7860,
        share=False,
        debug=False,
        show_error=True
    )
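Running python simple_gradio_app.py serves the demo at http://127.0.0.1:7860, per the launch block above. A small sketch for launching it programmatically instead, assuming gradio is installed and the image_corruption_utils, models and config modules (not part of this upload) are importable:

# Sketch: launching the demo from another script; the modules above are assumed present
from simple_gradio_app import create_interface

demo = create_interface()
demo.launch(server_name="0.0.0.0", server_port=7860)  # expose on the local network instead of localhost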
utils/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .data_utils import (
    create_client_dataloaders,
    create_global_test_loader,
    ImageCorruption,
    KidneyStoneDataset
)

__all__ = [
    'create_client_dataloaders',
    'create_global_test_loader',
    'ImageCorruption',
    'KidneyStoneDataset'
]
utils/data_utils.py
ADDED
@@ -0,0 +1,416 @@
"""
Data utilities for federated autoencoder training
"""

import os
import random
import numpy as np
from PIL import Image, ImageFilter, ImageEnhance
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from collections import Counter
import config


class ImageCorruption:
    """Class to handle various image corruptions"""

    def __init__(self, corruption_prob=0.1):
        self.corruption_prob = corruption_prob

    def gaussian_noise(self, image):
        """Add Gaussian noise to image"""
        if random.random() < self.corruption_prob:
            noise = torch.randn_like(image) * 0.1
            image = torch.clamp(image + noise, 0, 1)
        return image

    def salt_pepper_noise(self, image):
        """Add salt and pepper noise"""
        if random.random() < self.corruption_prob:
            noise = torch.rand_like(image)
            salt = noise > 0.95
            pepper = noise < 0.05
            image[salt] = 1.0
            image[pepper] = 0.0
        return image

    def blur(self, image):
        """Apply blur to image"""
        if random.random() < self.corruption_prob:
            # Convert to PIL for blur operation
            if isinstance(image, torch.Tensor):
                image_pil = transforms.ToPILImage()(image)
                image_pil = image_pil.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.5, 2.0)))
                image = transforms.ToTensor()(image_pil)
        return image

    def brightness_change(self, image):
        """Change brightness of image"""
        if random.random() < self.corruption_prob:
            factor = random.uniform(0.5, 1.5)
            image = torch.clamp(image * factor, 0, 1)
        return image

    def contrast_change(self, image):
        """Change contrast of image"""
        if random.random() < self.corruption_prob:
            mean = image.mean()
            factor = random.uniform(0.5, 1.5)
            image = torch.clamp((image - mean) * factor + mean, 0, 1)
        return image

    def apply_random_corruption(self, image):
        """Apply a random corruption to the image"""
        corruptions = [
            self.gaussian_noise,
            self.salt_pepper_noise,
            self.blur,
            self.brightness_change,
            self.contrast_change
        ]

        corruption_func = random.choice(corruptions)
        return corruption_func(image)


class KidneyStoneDataset(Dataset):
    """Custom dataset for kidney stone images"""

    def __init__(self, image_paths, labels, transform=None, corruption_prob=0.0):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.corruption = ImageCorruption(corruption_prob)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load image
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Apply corruption if specified
        if self.corruption.corruption_prob > 0:
            image = self.corruption.apply_random_corruption(image)

        return image, label


def load_dataset_paths(datasets=None, subversions=None):
    """Load image paths and labels from specified datasets and subversions"""
    all_paths = []
    all_labels = []

    # Use all datasets if none specified
    if datasets is None:
        datasets = config.DATASETS

    # Use all subversions if none specified
    if subversions is None:
        subversions = config.SUBVERSIONS

    for dataset_name in datasets:
        dataset_path = os.path.join(config.DATA_ROOT, dataset_name)

        if not os.path.exists(dataset_path):
            print(f"Warning: Dataset path does not exist: {dataset_path}")
            continue

        for subversion in subversions:
            subversion_path = os.path.join(dataset_path, subversion)

            if not os.path.exists(subversion_path):
                print(f"Warning: Subversion path does not exist: {subversion_path}")
                continue

            # Load training images (extract class from folder structure)
            train_path = os.path.join(subversion_path, "train")
            if os.path.exists(train_path):
                # Get all class folders in train directory
                class_folders = [d for d in os.listdir(train_path)
                                 if os.path.isdir(os.path.join(train_path, d))]

                for class_folder in class_folders:
                    class_path = os.path.join(train_path, class_folder)

                    # Load all images in this class folder
                    for img_file in os.listdir(class_path):
                        if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                            img_path = os.path.join(class_path, img_file)
                            all_paths.append(img_path)
                            # Create label with class information: "subversion_class"
                            all_labels.append(f"{subversion}_{class_folder}")

            # Load test images (extract class from folder structure)
            test_path = os.path.join(subversion_path, "test")
            if os.path.exists(test_path):
                # Get all class folders in test directory
                class_folders = [d for d in os.listdir(test_path)
                                 if os.path.isdir(os.path.join(test_path, d))]

                for class_folder in class_folders:
                    class_path = os.path.join(test_path, class_folder)

                    # Load all images in this class folder
                    for img_file in os.listdir(class_path):
                        if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                            img_path = os.path.join(class_path, img_file)
                            all_paths.append(img_path)
                            # Create label with class information: "subversion_class"
                            all_labels.append(f"{subversion}_{class_folder}")

    print(f"📊 Data loading summary:")
    print(f"   Total images: {len(all_paths)}")
    print(f"   Unique classes found: {len(set(all_labels))}")
    print(f"   Classes: {sorted(set(all_labels))}")

    return all_paths, all_labels


def redistribute_data_evenly(image_paths, labels, num_clients):
    """Redistribute data evenly among clients as fallback"""
    total_samples = len(image_paths)
    samples_per_client = total_samples // num_clients

    # Shuffle data
    combined = list(zip(image_paths, labels))
    np.random.shuffle(combined)

    client_datasets = []
    for i in range(num_clients):
        start_idx = i * samples_per_client
        if i == num_clients - 1:  # Last client gets remaining samples
            end_idx = total_samples
        else:
            end_idx = (i + 1) * samples_per_client

        client_data = combined[start_idx:end_idx]
        if client_data:
            client_paths, client_labels = zip(*client_data)
            client_datasets.append((list(client_paths), list(client_labels)))
            print(f"Client {i} redistributed with {len(client_paths)} samples")

    return client_datasets


def create_non_iid_distribution(image_paths, labels, num_clients, alpha=0.5):
    """Create non-IID data distribution using Dirichlet distribution"""

    # Convert labels to numeric
    unique_labels = list(set(labels))
    label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
    numeric_labels = [label_to_idx[label] for label in labels]

    num_classes = len(unique_labels)

    # Create Dirichlet distribution for each client
    client_distributions = np.random.dirichlet([alpha] * num_classes, num_clients)

    # Group data by class
    class_indices = {i: [] for i in range(num_classes)}
    for idx, label in enumerate(numeric_labels):
        class_indices[label].append(idx)

    # Distribute data to clients
    client_data = [[] for _ in range(num_clients)]

    for class_idx in range(num_classes):
        class_data = class_indices[class_idx]
        np.random.shuffle(class_data)

        # Calculate how many samples each client gets from this class
        total_samples = len(class_data)
        client_samples = (client_distributions[:, class_idx] * total_samples).astype(int)

        # Ensure we don't exceed total samples
        if client_samples.sum() > total_samples:
            excess = client_samples.sum() - total_samples
            client_samples[-1] -= excess

        # Distribute samples
        start_idx = 0
        for client_idx, num_samples in enumerate(client_samples):
            if num_samples > 0:
                end_idx = start_idx + num_samples
                client_data[client_idx].extend(class_data[start_idx:end_idx])
                start_idx = end_idx

    # Convert indices back to paths and labels
    client_datasets = []
    for client_idx, client_indices in enumerate(client_data):
        if len(client_indices) > 0:  # Accept any client with at least some data
            client_paths = [image_paths[i] for i in client_indices]
            client_labels = [labels[i] for i in client_indices]
            client_datasets.append((client_paths, client_labels))
            print(f"Client {client_idx} will have {len(client_indices)} samples")
        else:
            print(f"Warning: Client {client_idx} has no samples assigned")

    # If we don't have enough clients, redistribute the data more evenly
    if len(client_datasets) < num_clients:
        print(f"Warning: Only {len(client_datasets)} clients have sufficient data. Redistributing...")
        return redistribute_data_evenly(image_paths, labels, num_clients)

    return client_datasets


def safe_train_test_split(paths, labels, test_size=0.2, random_state=None):
    """
    Safely split data into train/test, handling classes with insufficient samples
    """
    # Count samples per class
    class_counts = Counter(labels)

    # Check if we can do stratified split
    min_class_size = min(class_counts.values())
    can_stratify = min_class_size >= 2

    if can_stratify:
        try:
            return train_test_split(
                paths, labels,
                test_size=test_size,
                random_state=random_state,
                stratify=labels
            )
        except ValueError as e:
            print(f"   ⚠️ Stratified split failed: {e}")
            can_stratify = False

    if not can_stratify:
        print(f"   📊 Using random split (some classes have <2 samples)")
        print(f"   📈 Class distribution: {dict(class_counts)}")

    # Use random split without stratification
    return train_test_split(
        paths, labels,
        test_size=test_size,
        random_state=random_state,
        stratify=None
    )


def get_data_transforms():
    """Get data transformations for training and testing"""

    train_transform = transforms.Compose([
        transforms.Resize(config.IMAGE_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    test_transform = transforms.Compose([
        transforms.Resize(config.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    return train_transform, test_transform


def create_client_dataloaders(num_clients, corruption_prob=0.1, alpha=0.5, datasets=None, subversions=None):
    """Create data loaders for all clients with non-IID distribution"""

    # Load data from specified datasets and subversions
    all_paths, all_labels = load_dataset_paths(datasets=datasets, subversions=subversions)

    print(f"Total images loaded: {len(all_paths)}")
    print(f"Unique labels: {set(all_labels)}")

    if len(all_paths) == 0:
        raise ValueError("No images found! Please check your dataset paths and subversions.")

    # Create non-IID distribution
    client_datasets = create_non_iid_distribution(all_paths, all_labels, num_clients, alpha)

    print(f"Created {len(client_datasets)} client datasets")

    # Get transforms
    train_transform, test_transform = get_data_transforms()

    # Create data loaders for each client
    client_loaders = []

    for i, (client_paths, client_labels) in enumerate(client_datasets):
        print(f"Client {i}: {len(client_paths)} samples")

        # Split into train/test for each client using safe splitting
        train_paths, test_paths, train_labels, test_labels = safe_train_test_split(
            client_paths, client_labels, test_size=0.2, random_state=config.SEED
        )

        # Create datasets
        train_dataset = KidneyStoneDataset(
            train_paths, train_labels,
            transform=train_transform,
            corruption_prob=corruption_prob
        )

        test_dataset = KidneyStoneDataset(
            test_paths, test_labels,
            transform=test_transform,
            corruption_prob=0.0  # No corruption for test data
        )

        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=True,
            num_workers=2
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=False,
            num_workers=2
        )

        client_loaders.append((train_loader, test_loader))

    return client_loaders


def create_global_test_loader(datasets=None, subversions=None):
    """Create a global test loader for evaluation"""

    # Load data from specified datasets and subversions
    all_paths, all_labels = load_dataset_paths(datasets=datasets, subversions=subversions)

    if len(all_paths) == 0:
        raise ValueError("No images found for global test loader! Please check your dataset paths and subversions.")

    # Use a subset for global testing with safe splitting
    _, test_paths, _, test_labels = safe_train_test_split(
        all_paths, all_labels, test_size=0.1, random_state=config.SEED
    )

    _, test_transform = get_data_transforms()

    test_dataset = KidneyStoneDataset(
        test_paths, test_labels,
        transform=test_transform,
        corruption_prob=0.0
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=2
    )

    return test_loader
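A minimal sketch of building non-IID federated splits and a shared evaluation loader with the helpers above. It assumes config.py defines DATA_ROOT, DATASETS, SUBVERSIONS, IMAGE_SIZE, BATCH_SIZE and SEED, and that the data follows the <dataset>/<subversion>/{train,test}/<class>/ layout expected by load_dataset_paths; the dataset name is taken from the example in simple_anomaly_detector.py.

# Sketch: federated loaders under the assumptions stated above
from utils.data_utils import create_client_dataloaders, create_global_test_loader

client_loaders = create_client_dataloaders(
    num_clients=3,
    corruption_prob=0.1,   # fraction of training images given a random corruption
    alpha=0.5,             # Dirichlet concentration: lower = more skewed class split per client
    datasets=["Michel Daudon (w256 1k v1)"],
    subversions=["MIX"],
)
global_test_loader = create_global_test_loader(
    datasets=["Michel Daudon (w256 1k v1)"], subversions=["MIX"]
)

for client_id, (train_loader, test_loader) in enumerate(client_loaders):
    print(client_id, len(train_loader.dataset), len(test_loader.dataset))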
utils/metrics.py
ADDED
@@ -0,0 +1,316 @@
"""
Metrics utilities for federated autoencoder evaluation
"""

import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report, silhouette_score
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder
import torch.nn.functional as F


def calculate_reconstruction_metrics(model, data_loader, device):
    """
    Calculate reconstruction-based metrics for autoencoder

    Args:
        model: Trained autoencoder model
        data_loader: DataLoader with test data
        device: Device to run evaluation on

    Returns:
        dict: Dictionary containing reconstruction metrics
    """
    model.eval()

    reconstruction_errors = []
    total_loss = 0.0
    num_samples = 0

    criterion = torch.nn.MSELoss(reduction='none')

    with torch.no_grad():
        for data, labels in data_loader:
            data = data.to(device)

            # Forward pass
            reconstructed, latent = model(data)

            # Calculate per-sample reconstruction error
            batch_errors = criterion(reconstructed, data).view(data.size(0), -1).mean(dim=1)
            reconstruction_errors.extend(batch_errors.cpu().numpy())

            # Calculate total loss
            total_loss += F.mse_loss(reconstructed, data).item() * data.size(0)
            num_samples += data.size(0)

    reconstruction_errors = np.array(reconstruction_errors)

    # Calculate reconstruction statistics
    avg_loss = total_loss / num_samples
    avg_reconstruction_error = np.mean(reconstruction_errors)
    std_reconstruction_error = np.std(reconstruction_errors)
    median_reconstruction_error = np.median(reconstruction_errors)

    # Calculate reconstruction quality metrics (lower is better)
    # Use percentiles to define "good" vs "poor" reconstruction
    percentile_25 = np.percentile(reconstruction_errors, 25)
    percentile_75 = np.percentile(reconstruction_errors, 75)

    # Define good reconstruction as bottom 25% of errors
    good_reconstruction = reconstruction_errors <= percentile_25
    poor_reconstruction = reconstruction_errors >= percentile_75

    # For autoencoder evaluation, we'll use a more meaningful approach:
    # Compare reconstruction quality across different error thresholds

    # Method 1: Use median as threshold (more stable than mean)
    median_threshold = np.median(reconstruction_errors)
    better_than_median = (reconstruction_errors <= median_threshold).astype(int)

    # Method 2: Use a stricter threshold (25th percentile) for "good" reconstructions
    strict_threshold = np.percentile(reconstruction_errors, 25)
    high_quality = (reconstruction_errors <= strict_threshold).astype(int)

    # Calculate "precision" as: how many predicted good are actually good
    # This is more like "consistency" - if we predict good, how often is it actually good?

    # For binary classification metrics, we need to define what we're classifying
    # Let's classify: "Is this reconstruction better than average?"

    # Ground truth: better than median (50% of samples)
    true_better_than_median = better_than_median

    # Prediction: better than 40th percentile (slightly more lenient)
    prediction_threshold = np.percentile(reconstruction_errors, 40)
    predicted_better = (reconstruction_errors <= prediction_threshold).astype(int)

    # Calculate metrics - but note these are somewhat artificial for autoencoders
    accuracy = accuracy_score(true_better_than_median, predicted_better)
    precision = precision_score(true_better_than_median, predicted_better, average='binary', zero_division=0)
    recall = recall_score(true_better_than_median, predicted_better, average='binary', zero_division=0)
    f1 = f1_score(true_better_than_median, predicted_better, average='binary', zero_division=0)

    # Add a note about what these metrics mean
    classification_note = (
        "Note: Classification metrics compare 40th vs 50th percentile thresholds. "
        "Perfect scores may indicate threshold alignment rather than model quality."
    )

    return {
        'loss': avg_loss,
        'reconstruction_error': avg_reconstruction_error,
        'reconstruction_std': std_reconstruction_error,
        'reconstruction_median': median_reconstruction_error,
        'reconstruction_25th': percentile_25,
        'reconstruction_75th': percentile_75,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'num_samples': num_samples,
        'good_reconstructions': np.sum(good_reconstruction),
        'poor_reconstructions': np.sum(poor_reconstruction),
        'classification_note': classification_note,
        'better_than_median_count': np.sum(better_than_median),
        'high_quality_count': np.sum(high_quality)
    }


def calculate_latent_classification_metrics(model, data_loader, device):
    """
    Calculate classification metrics using latent space representations

    Args:
        model: Trained autoencoder model
        data_loader: DataLoader with test data
        device: Device to run evaluation on

    Returns:
        dict: Dictionary containing latent space metrics
    """
    model.eval()

    latent_features = []
    true_labels = []

    with torch.no_grad():
        for data, labels in data_loader:
            data = data.to(device)

            # Get latent representations
            _, latent = model(data)
            latent_features.append(latent.cpu().numpy())
            true_labels.extend(labels)

    # Combine all latent features
    latent_features = np.vstack(latent_features)

    # Encode string labels to numeric
    label_encoder = LabelEncoder()
    numeric_labels = label_encoder.fit_transform(true_labels)
    unique_labels = np.unique(numeric_labels)
    n_classes = len(unique_labels)

    print(f"   🔍 Latent analysis: {n_classes} unique classes found")
    print(f"   📊 Class distribution: {dict(zip(label_encoder.classes_, np.bincount(numeric_labels)))}")

    if n_classes == 1:
        # Single class case - no meaningful classification possible
        return {
            'latent_accuracy': 0.0,  # No classification possible
            'latent_precision': 0.0,
            'latent_recall': 0.0,
            'latent_f1_score': 0.0,
            'silhouette_score': 0.0,  # No clustering possible
            'n_clusters': n_classes,
            'latent_dim': latent_features.shape[1],
            'cluster_quality': 'single_class'
        }

    # Perform clustering
    try:
        # Use the actual number of classes for clustering
        kmeans = KMeans(n_clusters=n_classes, random_state=42, n_init=10)
        cluster_predictions = kmeans.fit_predict(latent_features)

        # Calculate silhouette score for cluster quality
        if n_classes > 1 and len(set(cluster_predictions)) > 1:
            silhouette = silhouette_score(latent_features, cluster_predictions)
        else:
            silhouette = 0.0

        # For meaningful classification metrics, we need to align cluster labels with true labels
        # This is a complex problem, so we'll use a simpler approach:
        # Calculate how well the clustering separates the true classes

        # Method 1: Direct comparison (may not be meaningful due to label permutation)
        accuracy_direct = accuracy_score(numeric_labels, cluster_predictions)

        # Method 2: Best possible alignment between clusters and true labels
        try:
            from scipy.optimize import linear_sum_assignment
            from sklearn.metrics import confusion_matrix

            # Create confusion matrix
            cm = confusion_matrix(numeric_labels, cluster_predictions)

            # Find best assignment using Hungarian algorithm
            if cm.shape[0] == cm.shape[1]:  # Same number of clusters and classes
                row_ind, col_ind = linear_sum_assignment(-cm)  # Negative for maximization
                aligned_predictions = np.zeros_like(cluster_predictions)
                for i, j in zip(row_ind, col_ind):
                    aligned_predictions[cluster_predictions == j] = i

                # Calculate metrics with aligned labels
                accuracy = accuracy_score(numeric_labels, aligned_predictions)
                precision = precision_score(numeric_labels, aligned_predictions, average='weighted', zero_division=0)
                recall = recall_score(numeric_labels, aligned_predictions, average='weighted', zero_division=0)
                f1 = f1_score(numeric_labels, aligned_predictions, average='weighted', zero_division=0)
                cluster_quality = 'aligned'
            else:
                # Different number of clusters and classes - use direct comparison
                accuracy = accuracy_direct
                precision = precision_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
                recall = recall_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
                f1 = f1_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
                cluster_quality = 'unaligned'
        except ImportError:
            print(f"   ⚠️ scipy not available, using direct comparison")
            # Fallback to direct comparison without alignment
            accuracy = accuracy_direct
            precision = precision_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
            recall = recall_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
            f1 = f1_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
            cluster_quality = 'direct'

    except Exception as e:
        print(f"   ⚠️ Clustering failed: {e}")
        accuracy = precision = recall = f1 = silhouette = 0.0
        cluster_quality = 'failed'

    return {
        'latent_accuracy': accuracy,
        'latent_precision': precision,
        'latent_recall': recall,
        'latent_f1_score': f1,
        'silhouette_score': silhouette,
        'n_clusters': n_classes,
        'latent_dim': latent_features.shape[1],
        'cluster_quality': cluster_quality
    }


def calculate_comprehensive_metrics(model, data_loader, device):
    """
    Calculate comprehensive metrics for autoencoder evaluation

    Args:
        model: Trained autoencoder model
        data_loader: DataLoader with test data
        device: Device to run evaluation on

    Returns:
        dict: Dictionary containing all metrics
    """
    print(f"   🔄 Calculating reconstruction metrics...")
    recon_metrics = calculate_reconstruction_metrics(model, data_loader, device)

    print(f"   🧠 Calculating latent space metrics...")
    latent_metrics = calculate_latent_classification_metrics(model, data_loader, device)

    # Combine all metrics
    comprehensive_metrics = {
        **recon_metrics,
        **latent_metrics
    }

    return comprehensive_metrics


def print_metrics_summary(metrics, subversion_name):
    """Print a formatted summary of metrics"""
    print(f"\n📊 Metrics Summary for {subversion_name}:")
    print("=" * 60)

    # Reconstruction metrics
    print(f"🔄 Reconstruction Loss: {metrics['loss']:.6f}")
    print(f"📏 Reconstruction Error: {metrics['reconstruction_error']:.6f} ± {metrics['reconstruction_std']:.6f}")
    print(f"📊 Reconstruction Median: {metrics['reconstruction_median']:.6f}")
    print(f"📈 25th/75th Percentile: {metrics['reconstruction_25th']:.6f} / {metrics['reconstruction_75th']:.6f}")
    print(f"✅ Good Reconstructions: {metrics['good_reconstructions']}/{metrics['num_samples']} ({100*metrics['good_reconstructions']/metrics['num_samples']:.1f}%)")
    print(f"❌ Poor Reconstructions: {metrics['poor_reconstructions']}/{metrics['num_samples']} ({100*metrics['poor_reconstructions']/metrics['num_samples']:.1f}%)")
    print(f"🎯 Better than Median: {metrics['better_than_median_count']}/{metrics['num_samples']} ({100*metrics['better_than_median_count']/metrics['num_samples']:.1f}%)")
    print(f"⭐ High Quality (top 25%): {metrics['high_quality_count']}/{metrics['num_samples']} ({100*metrics['high_quality_count']/metrics['num_samples']:.1f}%)")

    # Classification metrics with explanation
    print(f"\n🎯 Reconstruction Classification Metrics:")
    print(f"   Accuracy: {metrics['accuracy']:.4f}")
    print(f"   Precision: {metrics['precision']:.4f}")
    print(f"   Recall: {metrics['recall']:.4f}")
    print(f"   F1-Score: {metrics['f1_score']:.4f}")
    print(f"   ℹ️ {metrics['classification_note']}")

    # Latent space metrics
    print(f"\n🧠 Latent Space Analysis:")
    print(f"   Latent Accuracy: {metrics['latent_accuracy']:.4f}")
    print(f"   Latent Precision: {metrics['latent_precision']:.4f}")
    print(f"   Latent Recall: {metrics['latent_recall']:.4f}")
    print(f"   Latent F1-Score: {metrics['latent_f1_score']:.4f}")
    print(f"   Silhouette Score: {metrics['silhouette_score']:.4f}")
    print(f"   Clusters Found: {metrics['n_clusters']}")
    print(f"   Latent Dimension: {metrics['latent_dim']}")
    print(f"   Cluster Quality: {metrics['cluster_quality']}")

    # Interpretation guide
    print(f"\n📖 Interpretation Guide:")
    print(f"   • Reconstruction Loss: Lower = better image reconstruction")
    print(f"   • Latent Accuracy: How well clustering separates kidney stone classes")
    print(f"   • Silhouette Score: Quality of latent space clustering (higher = better)")
    print(f"   • Perfect precision (1.0) in reconstruction metrics may indicate")
    print(f"     threshold alignment rather than exceptional model performance")

    print(f"\n📦 Total Samples: {metrics['num_samples']}")
    print("=" * 60)
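A minimal sketch of evaluating one of the uploaded checkpoints with the helpers above. It assumes the repository's Autoencoder class and config module (not part of this upload) are importable and expose the same names used in simple_anomaly_detector.py, and it reuses create_global_test_loader from utils.data_utils.

# Sketch: end-to-end evaluation under the assumptions stated above
import torch
import config
from models import Autoencoder
from utils.data_utils import create_global_test_loader
from utils.metrics import calculate_comprehensive_metrics, print_metrics_summary

device = torch.device(config.DEVICE)
model = Autoencoder(input_channels=config.CHANNELS, latent_dim=config.LATENT_DIM)
checkpoint = torch.load("models/Daudon_SEC/best_autoencoder_Daudon_SEC.pth",
                        map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)

test_loader = create_global_test_loader(
    datasets=["Michel Daudon (w256 1k v1)"], subversions=["SEC"]
)
metrics = calculate_comprehensive_metrics(model, test_loader, device)
print_metrics_summary(metrics, "Daudon_SEC")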