Ivanrs committed on
Commit e75d4ed · verified · 1 Parent(s): ae5898f

Upload 15 files

.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example_imgs/Section_Va_72845-3-18.png filter=lfs diff=lfs merge=lfs -text
+ example_imgs/TypeIa_LaosN°15_Image21-25.png filter=lfs diff=lfs merge=lfs -text
+ example_imgs/TypeIVa2_N47583_Notteb-11.png filter=lfs diff=lfs merge=lfs -text
+ example_imgs/TypeIVd_Sect_LC3373-65.png filter=lfs diff=lfs merge=lfs -text
+ example_imgs/typIVc_IVbsectbis-43.png filter=lfs diff=lfs merge=lfs -text
example_imgs/72222-SectionIVa+WK maj_0009-60.png ADDED
example_imgs/Section_Va_72845-3-18.png ADDED

Git LFS Details

  • SHA256: 77e446c7239e0d9dc169a90a4f1c2642eb77638a461ae317470f88a00cb036ba
  • Pointer size: 131 Bytes
  • Size of remote file: 128 kB
example_imgs/TypeIVa2_N47583_Notteb-11.png ADDED

Git LFS Details

  • SHA256: 1ca75687d778e60f8c5b65cbf785505cbf25bc12f6bcffdf7e9af881426e9243
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
example_imgs/TypeIVd_Sect_LC3373-65.png ADDED

Git LFS Details

  • SHA256: e607af6d6285c7d79d1506fe39f9f3a851dcd9a0a6593df357f00b61ba07ebb2
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB
example_imgs/TypeIa_LaosN°15_Image21-25.png ADDED

Git LFS Details

  • SHA256: c91b28a8d90256bc09442bd3b62751ea241a3ab9522398ba2e4f7d642d76de5a
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB
example_imgs/typIVc_IVbsectbis-43.png ADDED

Git LFS Details

  • SHA256: cbae5243d74e2f9c0156c804773ad9fe14c62d9b374c9798c5fee95b7bdddf38
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB
models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b63d58b9d42c9de2b912d2cbd770c1b1b93c591965e44e6d1479ee0caa01007
+ size 123579888
models/Daudon_SEC/best_autoencoder_Daudon_SEC.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:687f67252659ffbdde997adb689b24d80b8f77a4baedca2d52cdadbd4d30fa65
+ size 123579888
models/Daudon_SUR/best_autoencoder_Daudon_SUR.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e969a304b220177169ff5e46668110cb62fdd4752ffdf4ecb84bbf50d68bff61
+ size 123579888
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ flwr>=1.6.0
+ numpy>=1.24.0
+ Pillow>=9.5.0
+ matplotlib>=3.7.0
+ scikit-learn>=1.3.0
+ tqdm>=4.65.0
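These pins are minimum versions only; a typical setup (an assumption, not part of this upload) is a fresh virtual environment followed by pip install -r requirements.txt. Note that gradio, seaborn, and scipy are imported by simple_gradio_app.py below but are not pinned here; scipy comes in transitively with scikit-learn, while gradio and seaborn must be installed separately.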
simple_anomaly_detector.py ADDED
@@ -0,0 +1,214 @@
+ """
+ Simple Anomaly Detector using Reconstruction Error
+ A minimal implementation for testing corruption intensity using autoencoder reconstruction error
+ """
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from typing import Union
+ import random
+
+ from models import Autoencoder
+ from utils.data_utils import ImageCorruption
+ import config
+
+
+ def apply_corruption(image_tensor: torch.Tensor, corruption_type: str = 'random') -> torch.Tensor:
+     """
+     Simple function to apply corruption to an image tensor
+
+     Args:
+         image_tensor: Input image tensor (C, H, W)
+         corruption_type: Type of corruption ('noise', 'blur', 'brightness', 'contrast', 'random')
+
+     Returns:
+         Corrupted image tensor
+     """
+     # Create corruption object with 100% probability to ensure corruption is applied
+     corruptor = ImageCorruption(corruption_prob=1.0)
+
+     if corruption_type == 'noise':
+         return corruptor.gaussian_noise(image_tensor.clone())
+     elif corruption_type == 'blur':
+         return corruptor.blur(image_tensor.clone())
+     elif corruption_type == 'brightness':
+         return corruptor.brightness_change(image_tensor.clone())
+     elif corruption_type == 'contrast':
+         return corruptor.contrast_change(image_tensor.clone())
+     elif corruption_type == 'random':
+         return corruptor.apply_random_corruption(image_tensor.clone())
+     else:
+         raise ValueError(f"Unknown corruption type: {corruption_type}")
+
+
+ class SimpleAnomalyDetector:
+     """Simple anomaly detector based on reconstruction error"""
+
+     def __init__(self, model_path: str):
+         """
+         Initialize the detector with a trained autoencoder
+
+         Args:
+             model_path: Path to the trained autoencoder (.pth file)
+         """
+         self.device = torch.device(config.DEVICE)
+         self.model = self._load_model(model_path)
+         self.criterion = nn.MSELoss()
+
+         # Image preprocessing - simplified and more robust
+         self.transform = transforms.Compose([
+             transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+
+         print(f"✅ Anomaly detector ready! Using device: {self.device}")
+         print(f"📏 Image size: {config.IMAGE_SIZE}x{config.IMAGE_SIZE}")
+
+     def _load_model(self, model_path: str) -> Autoencoder:
+         """Load the trained autoencoder model"""
+         print(f"📥 Loading model from {model_path}")
+
+         # Load checkpoint (weights_only=False for compatibility with saved metadata)
+         checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
+
+         # Create model with same architecture
+         model = Autoencoder(
+             input_channels=config.CHANNELS,
+             latent_dim=config.LATENT_DIM
+         )
+
+         # Load trained weights
+         model.load_state_dict(checkpoint['model_state_dict'])
+         model.to(self.device)
+         model.eval()
+
+         return model
+
+     def calculate_reconstruction_error(self, image: Union[str, Image.Image, torch.Tensor]) -> float:
+         """
+         Calculate reconstruction error for a single image
+
+         Args:
+             image: Can be:
+                 - String path to image file
+                 - PIL Image object
+                 - PyTorch tensor (C, H, W) or (1, C, H, W)
+
+         Returns:
+             Reconstruction error as a float (higher = more anomalous)
+         """
+         # Get image size - handle both tuple and integer formats
+         if isinstance(config.IMAGE_SIZE, tuple):
+             target_size = config.IMAGE_SIZE  # (256, 256)
+         else:
+             target_size = (config.IMAGE_SIZE, config.IMAGE_SIZE)
+
+         # Convert input to tensor
+         if isinstance(image, str):
+             # Load from file path
+             try:
+                 image_pil = Image.open(image).convert('RGB')
+                 # Resize the image properly
+                 image_pil = image_pil.resize(target_size, Image.LANCZOS)
+                 image_tensor = transforms.ToTensor()(image_pil)
+                 # Apply normalization
+                 normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+                 image_tensor = normalize(image_tensor).unsqueeze(0)  # Add batch dimension
+             except Exception as e:
+                 raise ValueError(f"Error loading image from {image}: {e}")
+
+         elif isinstance(image, Image.Image):
+             # PIL Image
+             try:
+                 image_pil = image.convert('RGB')
+                 image_pil = image_pil.resize(target_size, Image.LANCZOS)
+                 image_tensor = transforms.ToTensor()(image_pil)
+                 normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+                 image_tensor = normalize(image_tensor).unsqueeze(0)
+             except Exception as e:
+                 raise ValueError(f"Error processing PIL Image: {e}")
+
+         elif isinstance(image, torch.Tensor):
+             # PyTorch tensor
+             if image.dim() == 3:  # (C, H, W)
+                 image_tensor = image.unsqueeze(0)  # Add batch dimension
+             elif image.dim() == 4:  # (1, C, H, W)
+                 image_tensor = image
+             else:
+                 raise ValueError(f"Unexpected tensor dimensions: {image.shape}")
+         else:
+             raise ValueError(f"Unsupported image type: {type(image)}")
+
+         # Move to device
+         image_tensor = image_tensor.to(self.device)
+
+         # Calculate reconstruction error
+         with torch.no_grad():
+             reconstructed, _ = self.model(image_tensor)
+             error = self.criterion(reconstructed, image_tensor)
+
+         return error.item()
+
+
+ def test_detector_example():
+     """Example usage of the simple anomaly detector"""
+
+     # You need to specify the path to your trained model
+     model_path = "models/All_Datasets_MIX/best_autoencoder_All_Datasets_MIX.pth"  # Change this!
+
+     try:
+         # Initialize detector
+         detector = SimpleAnomalyDetector(model_path)
+
+         # Test with some images from your dataset
+         from utils.data_utils import create_global_test_loader
+
+         # Get a test loader
+         test_loader = create_global_test_loader(
+             datasets=["Michel Daudon (w256 1k v1)", "Jonathan El-Beze (w256 1k v1)"],
+             subversions=["MIX"]
+         )
+
+         print("\n🧪 Testing reconstruction errors:")
+         print("=" * 50)
+
+         # Test a few images
+         for i, (images, labels) in enumerate(test_loader):
+             if i >= 3:  # Test only first 3 batches
+                 break
+
+             for j in range(min(2, images.size(0))):  # Test 2 images per batch
+                 clean_image = images[j]
+
+                 # Test clean image
+                 clean_error = detector.calculate_reconstruction_error(clean_image)
+
+                 # Test corrupted versions
+                 corrupted_noise = apply_corruption(clean_image, 'noise')
+                 corrupted_blur = apply_corruption(clean_image, 'blur')
+
+                 noise_error = detector.calculate_reconstruction_error(corrupted_noise)
+                 blur_error = detector.calculate_reconstruction_error(corrupted_blur)
+
+                 print(f"\nImage {i*2 + j + 1} (Class: {labels[j]}):")
+                 print(f"   Clean: {clean_error:.6f}")
+                 print(f"   Noise corrupted: {noise_error:.6f} (x{noise_error/clean_error:.2f})")
+                 print(f"   Blur corrupted: {blur_error:.6f} (x{blur_error/clean_error:.2f})")
+
+         print(f"\n💡 Usage tip: Higher reconstruction error = more anomalous/corrupted")
+         print(f"   You can set a threshold (e.g., 0.01) above which images are considered anomalous")
+
+     except FileNotFoundError:
+         print(f"❌ Model file not found: {model_path}")
+         print("   Please update the model_path variable with your actual model file")
+     except Exception as e:
+         print(f"❌ Error: {e}")
+
+
+ if __name__ == "__main__":
+     test_detector_example()
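A minimal usage sketch for the detector above. Hedged: it assumes the repo's config and models modules resolve (config.py and the Autoencoder definition are not part of this upload) and that config.IMAGE_SIZE is 256; the checkpoint and example image are files from this commit. As in test_detector_example, tensor inputs are scored without ImageNet normalization.

from PIL import Image
import torchvision.transforms as T

from simple_anomaly_detector import SimpleAnomalyDetector, apply_corruption

detector = SimpleAnomalyDetector("models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth")

# Score a clean example image; higher reconstruction error = more anomalous
clean_error = detector.calculate_reconstruction_error("example_imgs/Section_Va_72845-3-18.png")

# Corrupt a tensor version of the same image and compare
clean = Image.open("example_imgs/Section_Va_72845-3-18.png").convert('RGB')
tensor = T.ToTensor()(clean.resize((256, 256)))  # assumes config.IMAGE_SIZE == 256
noisy_error = detector.calculate_reconstruction_error(apply_corruption(tensor, 'noise'))

print(f"clean: {clean_error:.6f}  noisy: {noisy_error:.6f}")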
simple_gradio_app.py ADDED
@@ -0,0 +1,357 @@
+ """
+ Simple Gradio Application for Anomaly Detection Testing
+ Shows embedding analysis instead of reconstructed images
+ """
+
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from scipy import stats
+ import io
+ import base64
+ from simple_anomaly_detector import SimpleAnomalyDetector
+ from image_corruption_utils import corrupt_image
+
+
+ # Global variables to store models
+ models = {
+     "Daudon_MIX": "models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth",
+     "Daudon_SEC": "models/Daudon_SEC/best_autoencoder_Daudon_SEC.pth",
+     "Daudon_SUR": "models/Daudon_SUR/best_autoencoder_Daudon_SUR.pth"
+ }
+
+ current_detector = None
+ current_model_name = None
+
+
+ def load_model(model_name):
+     """Load the selected model"""
+     global current_detector, current_model_name
+
+     try:
+         if model_name != current_model_name:
+             print(f"Loading model: {model_name}")
+             model_path = models[model_name]
+             current_detector = SimpleAnomalyDetector(model_path)
+             current_model_name = model_name
+             return f"✅ Model {model_name} loaded!"
+         return f"✅ Model {model_name} already loaded"
+     except Exception as e:
+         return f"❌ Error loading {model_name}: {str(e)}"
+
+
+ def get_embedding_and_stats(image):
+     """Get embedding from autoencoder and calculate statistics"""
+     try:
+         from torchvision import transforms
+         import config
+
+         # Get image size
+         if isinstance(config.IMAGE_SIZE, tuple):
+             target_size = config.IMAGE_SIZE
+         else:
+             target_size = (config.IMAGE_SIZE, config.IMAGE_SIZE)
+
+         # Preprocess
+         image_pil = image.convert('RGB').resize(target_size, Image.LANCZOS)
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+         image_tensor = transform(image_pil).unsqueeze(0).to(current_detector.device)
+
+         # Get embedding (latent representation)
+         with torch.no_grad():
+             _, embedding = current_detector.model(image_tensor)
+
+         # Convert to numpy for analysis
+         embedding_np = embedding.squeeze(0).cpu().numpy().flatten()
+
+         # Calculate statistics
+         stats_dict = {
+             'mean': float(np.mean(embedding_np)),
+             'median': float(np.median(embedding_np)),
+             'std': float(np.std(embedding_np)),
+             'min': float(np.min(embedding_np)),
+             'max': float(np.max(embedding_np)),
+             'q25': float(np.percentile(embedding_np, 25)),
+             'q75': float(np.percentile(embedding_np, 75)),
+             'skewness': float(stats.skew(embedding_np)),
+             'kurtosis': float(stats.kurtosis(embedding_np)),
+             'variance': float(np.var(embedding_np)),
+             'range': float(np.max(embedding_np) - np.min(embedding_np)),
+             'iqr': float(np.percentile(embedding_np, 75) - np.percentile(embedding_np, 25))
+         }
+
+         # Create visualization
+         fig, axes = plt.subplots(2, 2, figsize=(12, 10))
+         fig.suptitle(f'Embedding Analysis (Dimension: {len(embedding_np)})', fontsize=16)
+
+         # Histogram
+         axes[0, 0].hist(embedding_np, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
+         axes[0, 0].set_title('Distribution Histogram')
+         axes[0, 0].set_xlabel('Embedding Values')
+         axes[0, 0].set_ylabel('Frequency')
+         axes[0, 0].grid(True, alpha=0.3)
+
+         # Box plot
+         axes[0, 1].boxplot(embedding_np, vert=True)
+         axes[0, 1].set_title('Box Plot')
+         axes[0, 1].set_ylabel('Embedding Values')
+         axes[0, 1].grid(True, alpha=0.3)
+
+         # Q-Q plot (normal distribution)
+         stats.probplot(embedding_np, dist="norm", plot=axes[1, 0])
+         axes[1, 0].set_title('Q-Q Plot (Normal Distribution)')
+         axes[1, 0].grid(True, alpha=0.3)
+
+         # Embedding values plot
+         axes[1, 1].plot(embedding_np, alpha=0.7, color='red', linewidth=1)
+         axes[1, 1].set_title('Embedding Values Sequence')
+         axes[1, 1].set_xlabel('Dimension Index')
+         axes[1, 1].set_ylabel('Value')
+         axes[1, 1].grid(True, alpha=0.3)
+
+         plt.tight_layout()
+
+         # Convert plot to image
+         buf = io.BytesIO()
+         plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
+         buf.seek(0)
+         plot_image = Image.open(buf)
+         plt.close()
+
+         return embedding_np, stats_dict, plot_image
+
+     except Exception as e:
+         print(f"Error in embedding analysis: {e}")
+         return None, {}, None
+
+
+ def format_stats_text(stats_dict):
+     """Format statistics into readable text"""
+     if not stats_dict:
+         return "❌ Error calculating statistics"
+
+     text = f"""📊 EMBEDDING STATISTICS
+
+ 🎯 Central Tendency:
+    Mean: {stats_dict['mean']:.6f}
+    Median: {stats_dict['median']:.6f}
+
+ 📏 Spread:
+    Std Dev: {stats_dict['std']:.6f}
+    Variance: {stats_dict['variance']:.6f}
+    Range: {stats_dict['range']:.6f}
+    IQR: {stats_dict['iqr']:.6f}
+
+ 📈 Extremes:
+    Min: {stats_dict['min']:.6f}
+    Max: {stats_dict['max']:.6f}
+    Q25: {stats_dict['q25']:.6f}
+    Q75: {stats_dict['q75']:.6f}
+
+ 🔄 Shape:
+    Skewness: {stats_dict['skewness']:.6f}
+    Kurtosis: {stats_dict['kurtosis']:.6f}
+
+ """
+
+     return text
+
+
+ def classify_image(reconstruction_error, threshold):
+     """Classify image as corrupted or clean based on threshold"""
+     is_corrupted = reconstruction_error > threshold
+     confidence = abs(reconstruction_error - threshold) / threshold * 100
+
+     if is_corrupted:
+         classification = "🚨 CORRUPTED/ANOMALOUS"
+         color_indicator = "🔴"
+         explanation = f"Reconstruction error ({reconstruction_error:.6f}) > Threshold ({threshold:.6f})"
+     else:
+         classification = "✅ CLEAN/NORMAL"
+         color_indicator = "🟢"
+         explanation = f"Reconstruction error ({reconstruction_error:.6f}) ≤ Threshold ({threshold:.6f})"
+
+     # Calculate how far from threshold (as percentage)
+     distance_pct = (reconstruction_error - threshold) / threshold * 100
+
+     classification_text = f"""🎯 ANOMALY CLASSIFICATION
+
+ {color_indicator} Status: {classification}
+
+ 📊 Details:
+    Reconstruction Error: {reconstruction_error:.6f}
+    Threshold: {threshold:.6f}
+    Distance from Threshold: {distance_pct:+.2f}%
+
+ 📝 Explanation:
+    {explanation}
+
+ 💡 Confidence Indicator:
+    • Distance > 50%: High confidence
+    • Distance 10-50%: Medium confidence
+    • Distance < 10%: Low confidence (near threshold)
+
+ 🎚️ Current Distance: {abs(distance_pct):.2f}% ({'High' if abs(distance_pct) > 50 else 'Medium' if abs(distance_pct) > 10 else 'Low'} confidence)"""
+
+     return classification_text, is_corrupted
+
+
+ def process_image(model_name, image, corruption_type, intensity, threshold):
+     """Main processing function"""
+     try:
+         # Load model
+         load_status = load_model(model_name)
+         if "❌" in load_status:
+             return None, None, load_status, 0.0, "", ""
+
+         if image is None:
+             return None, None, "❌ Please upload an image", 0.0, "", ""
+
+         # Apply corruption
+         if corruption_type == "none":
+             corrupted_image = image.copy()
+             corruption_info = "No corruption applied"
+         else:
+             corrupted_image = corrupt_image(image, corruption_type, intensity)
+             corruption_info = f"Applied {corruption_type} corruption (intensity: {intensity})"
+
+         # Calculate reconstruction error
+         error = current_detector.calculate_reconstruction_error(corrupted_image)
+
+         # Get embedding and statistics
+         embedding, stats_dict, plot_image = get_embedding_and_stats(corrupted_image)
+
+         # Format statistics text
+         stats_text = format_stats_text(stats_dict)
+
+         # Classify image based on threshold
+         classification_text, is_corrupted = classify_image(error, threshold)
+
+         # Status message
+         status = f"""✅ Processing complete!
+ 📊 Model: {model_name}
+ 🔧 {corruption_info}
+ 📈 Reconstruction Error: {error:.6f}
+ 🎚️ Threshold: {threshold:.6f}
+ 🎯 Classification: {'CORRUPTED' if is_corrupted else 'CLEAN'}
+ 🧠 Embedding Dimension: {len(embedding) if embedding is not None else 'N/A'}
+ 💡 Higher error = more anomalous"""
+
+         return corrupted_image, plot_image, status, error, stats_text, classification_text
+
+     except Exception as e:
+         error_msg = f"❌ Error: {str(e)}"
+         return None, None, error_msg, 0.0, "", ""
+
+
+ # Create interface
+ def create_interface():
+     with gr.Blocks(title="Anomaly Detection Tester") as demo:
+         gr.Markdown("# 🔍 Federated Autoencoder for Kidney Stone Image Corruption Detection")
+         gr.Markdown("Upload an image, analyze its latent representation, and classify it as corrupted or clean using a threshold.")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("### ⚙️ Model & Corruption Settings")
+
+                 model_dropdown = gr.Dropdown(
+                     choices=list(models.keys()),
+                     value="Daudon_MIX",
+                     label="🤖 Select Model"
+                 )
+
+                 corruption_dropdown = gr.Dropdown(
+                     choices=["none", "noise", "blur", "brightness", "contrast", "saturation", "random"],
+                     value="none",
+                     label="🔧 Corruption Type"
+                 )
+
+                 intensity_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=3.0,
+                     value=1.0,
+                     step=0.1,
+                     label="💪 Corruption Intensity"
+                 )
+
+                 gr.Markdown("### 🎚️ Classification Settings")
+
+                 threshold_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=3.0,
+                     value=1.0,
+                     step=0.1,
+                     label="🎯 Anomaly Threshold (Reconstruction Error)"
+                 )
+
+                 gr.Markdown("### 📸 Image Input")
+
+                 image_input = gr.Image(type="pil", label="Upload Image")
+
+                 # Add examples section
+                 gr.Markdown("### 📁 Example Images")
+
+                 # You can specify your example image paths here
+                 example_images = [
+                     ["example_imgs/TypeIa_LaosN°15_Image21-25.png", "Clean Daudon MIX-Subtype_Ia"],
+                     ["example_imgs/72222-SectionIVa+WK maj_0009-60.png", "Clean Daudon MIX-Subtype_IVa"],
+                     ["example_imgs/TypeIVa2_N47583_Notteb-11.png", "Clean Daudon MIX-Subtype_IVa2"],
+                     ["example_imgs/typIVc_IVbsectbis-43.png", "Clean Daudon MIX-Subtype_IVc"],
+                     ["example_imgs/TypeIVd_Sect_LC3373-65.png", "Clean Daudon MIX-Subtype_IVd"],
+                     ["example_imgs/Section_Va_72845-3-18.png", "Clean Daudon MIX-Subtype_Va"],
+                 ]
+
+                 examples_component = gr.Examples(
+                     examples=example_images,
+                     inputs=image_input,
+                     label="Daudon MIX Example Clean Images",
+                     examples_per_page=6,
+                     cache_examples=False
+                 )
+
+                 process_btn = gr.Button("🚀 Analyze & Classify", variant="primary", size="lg")
+
+             with gr.Column(scale=1):
+                 gr.Markdown("### 📊 Results")
+
+                 status_output = gr.Textbox(label="📋 Status", lines=8)
+                 error_output = gr.Number(label="📈 Reconstruction Error", precision=6)
+
+                 corrupted_output = gr.Image(label="🔧 Input Image (Corrupted)")
+
+
+         with gr.Row():
+             embedding_plot = gr.Image(label="🧠 Embedding Analysis")
+
+         with gr.Row():
+             stats_output = gr.Textbox(label="📊 Embedding Statistics", lines=20)
+             classification_output = gr.Textbox(label="🎯 Classification Result", lines=15)
+
+         # Connect the button
+         process_btn.click(
+             fn=process_image,
+             inputs=[model_dropdown, image_input, corruption_dropdown, intensity_slider, threshold_slider],
+             outputs=[corrupted_output, embedding_plot, status_output, error_output, stats_output, classification_output]
+         )
+
+
+     return demo
+
+
+ if __name__ == "__main__":
+     print("🚀 Starting Embedding Analysis App...")
+
+     demo = create_interface()
+     demo.launch(
+         server_name="127.0.0.1",
+         server_port=7860,
+         share=False,
+         debug=False,
+         show_error=True
+     )
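One caveat worth flagging: the app imports corrupt_image from image_corruption_utils, which is not among the files in this upload, so launching it assumes that module is already present in the repo. The threshold logic itself is a pure function and can be exercised directly; a small sketch, assuming the module's imports resolve:

from simple_gradio_app import classify_image

# An error 20% above the threshold is flagged as corrupted (medium confidence)
text, is_corrupted = classify_image(reconstruction_error=1.2, threshold=1.0)
print(is_corrupted)  # True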
utils/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from .data_utils import (
+     create_client_dataloaders,
+     create_global_test_loader,
+     ImageCorruption,
+     KidneyStoneDataset
+ )
+
+ __all__ = [
+     'create_client_dataloaders',
+     'create_global_test_loader',
+     'ImageCorruption',
+     'KidneyStoneDataset'
+ ]
utils/data_utils.py ADDED
@@ -0,0 +1,416 @@
+ """
+ Data utilities for federated autoencoder training
+ """
+
+ import os
+ import random
+ import numpy as np
+ from PIL import Image, ImageFilter, ImageEnhance
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from sklearn.model_selection import train_test_split
+ from collections import Counter
+ import config
+
+
+ class ImageCorruption:
+     """Class to handle various image corruptions"""
+
+     def __init__(self, corruption_prob=0.1):
+         self.corruption_prob = corruption_prob
+
+     def gaussian_noise(self, image):
+         """Add Gaussian noise to image"""
+         if random.random() < self.corruption_prob:
+             noise = torch.randn_like(image) * 0.1
+             image = torch.clamp(image + noise, 0, 1)
+         return image
+
+     def salt_pepper_noise(self, image):
+         """Add salt and pepper noise"""
+         if random.random() < self.corruption_prob:
+             noise = torch.rand_like(image)
+             salt = noise > 0.95
+             pepper = noise < 0.05
+             image[salt] = 1.0
+             image[pepper] = 0.0
+         return image
+
+     def blur(self, image):
+         """Apply blur to image"""
+         if random.random() < self.corruption_prob:
+             # Convert to PIL for blur operation
+             if isinstance(image, torch.Tensor):
+                 image_pil = transforms.ToPILImage()(image)
+                 image_pil = image_pil.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.5, 2.0)))
+                 image = transforms.ToTensor()(image_pil)
+         return image
+
+     def brightness_change(self, image):
+         """Change brightness of image"""
+         if random.random() < self.corruption_prob:
+             factor = random.uniform(0.5, 1.5)
+             image = torch.clamp(image * factor, 0, 1)
+         return image
+
+     def contrast_change(self, image):
+         """Change contrast of image"""
+         if random.random() < self.corruption_prob:
+             mean = image.mean()
+             factor = random.uniform(0.5, 1.5)
+             image = torch.clamp((image - mean) * factor + mean, 0, 1)
+         return image
+
+     def apply_random_corruption(self, image):
+         """Apply a random corruption to the image"""
+         corruptions = [
+             self.gaussian_noise,
+             self.salt_pepper_noise,
+             self.blur,
+             self.brightness_change,
+             self.contrast_change
+         ]
+
+         corruption_func = random.choice(corruptions)
+         return corruption_func(image)
+
+
+ class KidneyStoneDataset(Dataset):
+     """Custom dataset for kidney stone images"""
+
+     def __init__(self, image_paths, labels, transform=None, corruption_prob=0.0):
+         self.image_paths = image_paths
+         self.labels = labels
+         self.transform = transform
+         self.corruption = ImageCorruption(corruption_prob)
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         image_path = self.image_paths[idx]
+         label = self.labels[idx]
+
+         # Load image
+         image = Image.open(image_path).convert('RGB')
+
+         if self.transform:
+             image = self.transform(image)
+
+         # Apply corruption if specified
+         if self.corruption.corruption_prob > 0:
+             image = self.corruption.apply_random_corruption(image)
+
+         return image, label
+
+
+ def load_dataset_paths(datasets=None, subversions=None):
+     """Load image paths and labels from specified datasets and subversions"""
+     all_paths = []
+     all_labels = []
+
+     # Use all datasets if none specified
+     if datasets is None:
+         datasets = config.DATASETS
+
+     # Use all subversions if none specified
+     if subversions is None:
+         subversions = config.SUBVERSIONS
+
+     for dataset_name in datasets:
+         dataset_path = os.path.join(config.DATA_ROOT, dataset_name)
+
+         if not os.path.exists(dataset_path):
+             print(f"Warning: Dataset path does not exist: {dataset_path}")
+             continue
+
+         for subversion in subversions:
+             subversion_path = os.path.join(dataset_path, subversion)
+
+             if not os.path.exists(subversion_path):
+                 print(f"Warning: Subversion path does not exist: {subversion_path}")
+                 continue
+
+             # Load training images (extract class from folder structure)
+             train_path = os.path.join(subversion_path, "train")
+             if os.path.exists(train_path):
+                 # Get all class folders in train directory
+                 class_folders = [d for d in os.listdir(train_path)
+                                  if os.path.isdir(os.path.join(train_path, d))]
+
+                 for class_folder in class_folders:
+                     class_path = os.path.join(train_path, class_folder)
+
+                     # Load all images in this class folder
+                     for img_file in os.listdir(class_path):
+                         if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
+                             img_path = os.path.join(class_path, img_file)
+                             all_paths.append(img_path)
+                             # Create label with class information: "subversion_class"
+                             all_labels.append(f"{subversion}_{class_folder}")
+
+             # Load test images (extract class from folder structure)
+             test_path = os.path.join(subversion_path, "test")
+             if os.path.exists(test_path):
+                 # Get all class folders in test directory
+                 class_folders = [d for d in os.listdir(test_path)
+                                  if os.path.isdir(os.path.join(test_path, d))]
+
+                 for class_folder in class_folders:
+                     class_path = os.path.join(test_path, class_folder)
+
+                     # Load all images in this class folder
+                     for img_file in os.listdir(class_path):
+                         if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
+                             img_path = os.path.join(class_path, img_file)
+                             all_paths.append(img_path)
+                             # Create label with class information: "subversion_class"
+                             all_labels.append(f"{subversion}_{class_folder}")
+
+     print(f"📊 Data loading summary:")
+     print(f"   Total images: {len(all_paths)}")
+     print(f"   Unique classes found: {len(set(all_labels))}")
+     print(f"   Classes: {sorted(set(all_labels))}")
+
+     return all_paths, all_labels
+
+
+ def redistribute_data_evenly(image_paths, labels, num_clients):
+     """Redistribute data evenly among clients as fallback"""
+     total_samples = len(image_paths)
+     samples_per_client = total_samples // num_clients
+
+     # Shuffle data
+     combined = list(zip(image_paths, labels))
+     np.random.shuffle(combined)
+
+     client_datasets = []
+     for i in range(num_clients):
+         start_idx = i * samples_per_client
+         if i == num_clients - 1:  # Last client gets remaining samples
+             end_idx = total_samples
+         else:
+             end_idx = (i + 1) * samples_per_client
+
+         client_data = combined[start_idx:end_idx]
+         if client_data:
+             client_paths, client_labels = zip(*client_data)
+             client_datasets.append((list(client_paths), list(client_labels)))
+             print(f"Client {i} redistributed with {len(client_paths)} samples")
+
+     return client_datasets
+
+
+ def create_non_iid_distribution(image_paths, labels, num_clients, alpha=0.5):
+     """Create non-IID data distribution using Dirichlet distribution"""
+
+     # Convert labels to numeric
+     unique_labels = list(set(labels))
+     label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
+     numeric_labels = [label_to_idx[label] for label in labels]
+
+     num_classes = len(unique_labels)
+
+     # Create Dirichlet distribution for each client
+     client_distributions = np.random.dirichlet([alpha] * num_classes, num_clients)
+
+     # Group data by class
+     class_indices = {i: [] for i in range(num_classes)}
+     for idx, label in enumerate(numeric_labels):
+         class_indices[label].append(idx)
+
+     # Distribute data to clients
+     client_data = [[] for _ in range(num_clients)]
+
+     for class_idx in range(num_classes):
+         class_data = class_indices[class_idx]
+         np.random.shuffle(class_data)
+
+         # Calculate how many samples each client gets from this class
+         total_samples = len(class_data)
+         client_samples = (client_distributions[:, class_idx] * total_samples).astype(int)
+
+         # Ensure we don't exceed total samples
+         if client_samples.sum() > total_samples:
+             excess = client_samples.sum() - total_samples
+             client_samples[-1] -= excess
+
+         # Distribute samples
+         start_idx = 0
+         for client_idx, num_samples in enumerate(client_samples):
+             if num_samples > 0:
+                 end_idx = start_idx + num_samples
+                 client_data[client_idx].extend(class_data[start_idx:end_idx])
+                 start_idx = end_idx
+
+     # Convert indices back to paths and labels
+     client_datasets = []
+     for client_idx, client_indices in enumerate(client_data):
+         if len(client_indices) > 0:  # Accept any client with at least some data
+             client_paths = [image_paths[i] for i in client_indices]
+             client_labels = [labels[i] for i in client_indices]
+             client_datasets.append((client_paths, client_labels))
+             print(f"Client {client_idx} will have {len(client_indices)} samples")
+         else:
+             print(f"Warning: Client {client_idx} has no samples assigned")
+
+     # If we don't have enough clients, redistribute the data more evenly
+     if len(client_datasets) < num_clients:
+         print(f"Warning: Only {len(client_datasets)} clients have sufficient data. Redistributing...")
+         return redistribute_data_evenly(image_paths, labels, num_clients)
+
+     return client_datasets
+
+
+ def safe_train_test_split(paths, labels, test_size=0.2, random_state=None):
+     """
+     Safely split data into train/test, handling classes with insufficient samples
+     """
+     # Count samples per class
+     class_counts = Counter(labels)
+
+     # Check if we can do stratified split
+     min_class_size = min(class_counts.values())
+     can_stratify = min_class_size >= 2
+
+     if can_stratify:
+         try:
+             return train_test_split(
+                 paths, labels,
+                 test_size=test_size,
+                 random_state=random_state,
+                 stratify=labels
+             )
+         except ValueError as e:
+             print(f"   ⚠️ Stratified split failed: {e}")
+             can_stratify = False
+
+     if not can_stratify:
+         print(f"   📊 Using random split (some classes have <2 samples)")
+         print(f"   📈 Class distribution: {dict(class_counts)}")
+
+         # Use random split without stratification
+         return train_test_split(
+             paths, labels,
+             test_size=test_size,
+             random_state=random_state,
+             stratify=None
+         )
+
+
+ def get_data_transforms():
+     """Get data transformations for training and testing"""
+
+     train_transform = transforms.Compose([
+         transforms.Resize(config.IMAGE_SIZE),
+         transforms.RandomHorizontalFlip(p=0.5),
+         transforms.RandomRotation(10),
+         transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+     test_transform = transforms.Compose([
+         transforms.Resize(config.IMAGE_SIZE),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+     return train_transform, test_transform
+
+
+ def create_client_dataloaders(num_clients, corruption_prob=0.1, alpha=0.5, datasets=None, subversions=None):
+     """Create data loaders for all clients with non-IID distribution"""
+
+     # Load data from specified datasets and subversions
+     all_paths, all_labels = load_dataset_paths(datasets=datasets, subversions=subversions)
+
+     print(f"Total images loaded: {len(all_paths)}")
+     print(f"Unique labels: {set(all_labels)}")
+
+     if len(all_paths) == 0:
+         raise ValueError("No images found! Please check your dataset paths and subversions.")
+
+     # Create non-IID distribution
+     client_datasets = create_non_iid_distribution(all_paths, all_labels, num_clients, alpha)
+
+     print(f"Created {len(client_datasets)} client datasets")
+
+     # Get transforms
+     train_transform, test_transform = get_data_transforms()
+
+     # Create data loaders for each client
+     client_loaders = []
+
+     for i, (client_paths, client_labels) in enumerate(client_datasets):
+         print(f"Client {i}: {len(client_paths)} samples")
+
+         # Split into train/test for each client using safe splitting
+         train_paths, test_paths, train_labels, test_labels = safe_train_test_split(
+             client_paths, client_labels, test_size=0.2, random_state=config.SEED
+         )
+
+         # Create datasets
+         train_dataset = KidneyStoneDataset(
+             train_paths, train_labels,
+             transform=train_transform,
+             corruption_prob=corruption_prob
+         )
+
+         test_dataset = KidneyStoneDataset(
+             test_paths, test_labels,
+             transform=test_transform,
+             corruption_prob=0.0  # No corruption for test data
+         )
+
+         # Create data loaders
+         train_loader = DataLoader(
+             train_dataset,
+             batch_size=config.BATCH_SIZE,
+             shuffle=True,
+             num_workers=2
+         )
+
+         test_loader = DataLoader(
+             test_dataset,
+             batch_size=config.BATCH_SIZE,
+             shuffle=False,
+             num_workers=2
+         )
+
+         client_loaders.append((train_loader, test_loader))
+
+     return client_loaders
+
+
+ def create_global_test_loader(datasets=None, subversions=None):
+     """Create a global test loader for evaluation"""
+
+     # Load data from specified datasets and subversions
+     all_paths, all_labels = load_dataset_paths(datasets=datasets, subversions=subversions)
+
+     if len(all_paths) == 0:
+         raise ValueError("No images found for global test loader! Please check your dataset paths and subversions.")
+
+     # Use a subset for global testing with safe splitting
+     _, test_paths, _, test_labels = safe_train_test_split(
+         all_paths, all_labels, test_size=0.1, random_state=config.SEED
+     )
+
+     _, test_transform = get_data_transforms()
+
+     test_dataset = KidneyStoneDataset(
+         test_paths, test_labels,
+         transform=test_transform,
+         corruption_prob=0.0
+     )
+
+     test_loader = DataLoader(
+         test_dataset,
+         batch_size=config.BATCH_SIZE,
+         shuffle=False,
+         num_workers=2
+     )
+
+     return test_loader
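A sketch of how these loaders are typically consumed. Hedged: it assumes config defines DATA_ROOT, DATASETS, SUBVERSIONS, IMAGE_SIZE, BATCH_SIZE, and SEED (config.py is not part of this upload) and that the datasets exist on disk. Smaller alpha values make the Dirichlet partition more skewed, i.e. more strongly non-IID.

from utils.data_utils import create_client_dataloaders, create_global_test_loader

# Five federated clients; 10% of each client's training images randomly corrupted
client_loaders = create_client_dataloaders(num_clients=5, corruption_prob=0.1, alpha=0.5)
for i, (train_loader, test_loader) in enumerate(client_loaders):
    print(f"Client {i}: {len(train_loader.dataset)} train / {len(test_loader.dataset)} test")

# Shared held-out loader for server-side evaluation
global_test_loader = create_global_test_loader(subversions=["MIX"])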
utils/metrics.py ADDED
@@ -0,0 +1,316 @@
+ """
+ Metrics utilities for federated autoencoder evaluation
+ """
+
+ import torch
+ import numpy as np
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
+ from sklearn.metrics import classification_report, silhouette_score
+ from sklearn.cluster import KMeans
+ from sklearn.preprocessing import LabelEncoder
+ import torch.nn.functional as F
+
+
+ def calculate_reconstruction_metrics(model, data_loader, device):
+     """
+     Calculate reconstruction-based metrics for autoencoder
+
+     Args:
+         model: Trained autoencoder model
+         data_loader: DataLoader with test data
+         device: Device to run evaluation on
+
+     Returns:
+         dict: Dictionary containing reconstruction metrics
+     """
+     model.eval()
+
+     reconstruction_errors = []
+     total_loss = 0.0
+     num_samples = 0
+
+     criterion = torch.nn.MSELoss(reduction='none')
+
+     with torch.no_grad():
+         for data, labels in data_loader:
+             data = data.to(device)
+
+             # Forward pass
+             reconstructed, latent = model(data)
+
+             # Calculate per-sample reconstruction error
+             batch_errors = criterion(reconstructed, data).view(data.size(0), -1).mean(dim=1)
+             reconstruction_errors.extend(batch_errors.cpu().numpy())
+
+             # Calculate total loss
+             total_loss += F.mse_loss(reconstructed, data).item() * data.size(0)
+             num_samples += data.size(0)
+
+     reconstruction_errors = np.array(reconstruction_errors)
+
+     # Calculate reconstruction statistics
+     avg_loss = total_loss / num_samples
+     avg_reconstruction_error = np.mean(reconstruction_errors)
+     std_reconstruction_error = np.std(reconstruction_errors)
+     median_reconstruction_error = np.median(reconstruction_errors)
+
+     # Calculate reconstruction quality metrics (lower is better)
+     # Use percentiles to define "good" vs "poor" reconstruction
+     percentile_25 = np.percentile(reconstruction_errors, 25)
+     percentile_75 = np.percentile(reconstruction_errors, 75)
+
+     # Define good reconstruction as bottom 25% of errors
+     good_reconstruction = reconstruction_errors <= percentile_25
+     poor_reconstruction = reconstruction_errors >= percentile_75
+
+     # For autoencoder evaluation, we'll use a more meaningful approach:
+     # Compare reconstruction quality across different error thresholds
+
+     # Method 1: Use median as threshold (more stable than mean)
+     median_threshold = np.median(reconstruction_errors)
+     better_than_median = (reconstruction_errors <= median_threshold).astype(int)
+
+     # Method 2: Use a stricter threshold (25th percentile) for "good" reconstructions
+     strict_threshold = np.percentile(reconstruction_errors, 25)
+     high_quality = (reconstruction_errors <= strict_threshold).astype(int)
+
+     # Calculate "precision" as: how many predicted good are actually good
+     # This is more like "consistency" - if we predict good, how often is it actually good?
+
+     # For binary classification metrics, we need to define what we're classifying
+     # Let's classify: "Is this reconstruction better than average?"
+
+     # Ground truth: better than median (50% of samples)
+     true_better_than_median = better_than_median
+
+     # Prediction: better than 40th percentile (slightly more lenient)
+     prediction_threshold = np.percentile(reconstruction_errors, 40)
+     predicted_better = (reconstruction_errors <= prediction_threshold).astype(int)
+
+     # Calculate metrics - but note these are somewhat artificial for autoencoders
+     accuracy = accuracy_score(true_better_than_median, predicted_better)
+     precision = precision_score(true_better_than_median, predicted_better, average='binary', zero_division=0)
+     recall = recall_score(true_better_than_median, predicted_better, average='binary', zero_division=0)
+     f1 = f1_score(true_better_than_median, predicted_better, average='binary', zero_division=0)
+
+     # Add a note about what these metrics mean
+     classification_note = (
+         "Note: Classification metrics compare 40th vs 50th percentile thresholds. "
+         "Perfect scores may indicate threshold alignment rather than model quality."
+     )
+
+     return {
+         'loss': avg_loss,
+         'reconstruction_error': avg_reconstruction_error,
+         'reconstruction_std': std_reconstruction_error,
+         'reconstruction_median': median_reconstruction_error,
+         'reconstruction_25th': percentile_25,
+         'reconstruction_75th': percentile_75,
+         'accuracy': accuracy,
+         'precision': precision,
+         'recall': recall,
+         'f1_score': f1,
+         'num_samples': num_samples,
+         'good_reconstructions': np.sum(good_reconstruction),
+         'poor_reconstructions': np.sum(poor_reconstruction),
+         'classification_note': classification_note,
+         'better_than_median_count': np.sum(better_than_median),
+         'high_quality_count': np.sum(high_quality)
+     }
+
+
+ def calculate_latent_classification_metrics(model, data_loader, device):
+     """
+     Calculate classification metrics using latent space representations
+
+     Args:
+         model: Trained autoencoder model
+         data_loader: DataLoader with test data
+         device: Device to run evaluation on
+
+     Returns:
+         dict: Dictionary containing latent space metrics
+     """
+     model.eval()
+
+     latent_features = []
+     true_labels = []
+
+     with torch.no_grad():
+         for data, labels in data_loader:
+             data = data.to(device)
+
+             # Get latent representations
+             _, latent = model(data)
+             latent_features.append(latent.cpu().numpy())
+             true_labels.extend(labels)
+
+     # Combine all latent features
+     latent_features = np.vstack(latent_features)
+
+     # Encode string labels to numeric
+     label_encoder = LabelEncoder()
+     numeric_labels = label_encoder.fit_transform(true_labels)
+     unique_labels = np.unique(numeric_labels)
+     n_classes = len(unique_labels)
+
+     print(f"   🔍 Latent analysis: {n_classes} unique classes found")
+     print(f"   📊 Class distribution: {dict(zip(label_encoder.classes_, np.bincount(numeric_labels)))}")
+
+     if n_classes == 1:
+         # Single class case - no meaningful classification possible
+         return {
+             'latent_accuracy': 0.0,  # No classification possible
+             'latent_precision': 0.0,
+             'latent_recall': 0.0,
+             'latent_f1_score': 0.0,
+             'silhouette_score': 0.0,  # No clustering possible
+             'n_clusters': n_classes,
+             'latent_dim': latent_features.shape[1],
+             'cluster_quality': 'single_class'
+         }
+
+     # Perform clustering
+     try:
+         # Use the actual number of classes for clustering
+         kmeans = KMeans(n_clusters=n_classes, random_state=42, n_init=10)
+         cluster_predictions = kmeans.fit_predict(latent_features)
+
+         # Calculate silhouette score for cluster quality
+         if n_classes > 1 and len(set(cluster_predictions)) > 1:
+             silhouette = silhouette_score(latent_features, cluster_predictions)
+         else:
+             silhouette = 0.0
+
+         # For meaningful classification metrics, we need to align cluster labels with true labels
+         # This is a complex problem, so we'll use a simpler approach:
+         # Calculate how well the clustering separates the true classes
+
+         # Method 1: Direct comparison (may not be meaningful due to label permutation)
+         accuracy_direct = accuracy_score(numeric_labels, cluster_predictions)
+
+         # Method 2: Best possible alignment between clusters and true labels
+         try:
+             from scipy.optimize import linear_sum_assignment
+             from sklearn.metrics import confusion_matrix
+
+             # Create confusion matrix
+             cm = confusion_matrix(numeric_labels, cluster_predictions)
+
+             # Find best assignment using Hungarian algorithm
+             if cm.shape[0] == cm.shape[1]:  # Same number of clusters and classes
+                 row_ind, col_ind = linear_sum_assignment(-cm)  # Negative for maximization
+                 aligned_predictions = np.zeros_like(cluster_predictions)
+                 for i, j in zip(row_ind, col_ind):
+                     aligned_predictions[cluster_predictions == j] = i
+
+                 # Calculate metrics with aligned labels
+                 accuracy = accuracy_score(numeric_labels, aligned_predictions)
+                 precision = precision_score(numeric_labels, aligned_predictions, average='weighted', zero_division=0)
+                 recall = recall_score(numeric_labels, aligned_predictions, average='weighted', zero_division=0)
+                 f1 = f1_score(numeric_labels, aligned_predictions, average='weighted', zero_division=0)
+                 cluster_quality = 'aligned'
+             else:
+                 # Different number of clusters and classes - use direct comparison
+                 accuracy = accuracy_direct
+                 precision = precision_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
+                 recall = recall_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
+                 f1 = f1_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
+                 cluster_quality = 'unaligned'
+         except ImportError:
+             print(f"   ⚠️ scipy not available, using direct comparison")
+             # Fallback to direct comparison without alignment
+             accuracy = accuracy_direct
+             precision = precision_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
+             recall = recall_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
+             f1 = f1_score(numeric_labels, cluster_predictions, average='weighted', zero_division=0)
+             cluster_quality = 'direct'
+
+     except Exception as e:
+         print(f"   ⚠️ Clustering failed: {e}")
+         accuracy = precision = recall = f1 = silhouette = 0.0
+         cluster_quality = 'failed'
+
+     return {
+         'latent_accuracy': accuracy,
+         'latent_precision': precision,
+         'latent_recall': recall,
+         'latent_f1_score': f1,
+         'silhouette_score': silhouette,
+         'n_clusters': n_classes,
+         'latent_dim': latent_features.shape[1],
+         'cluster_quality': cluster_quality
+     }
+
+
+ def calculate_comprehensive_metrics(model, data_loader, device):
+     """
+     Calculate comprehensive metrics for autoencoder evaluation
+
+     Args:
+         model: Trained autoencoder model
+         data_loader: DataLoader with test data
+         device: Device to run evaluation on
+
+     Returns:
+         dict: Dictionary containing all metrics
+     """
+     print(f"   🔄 Calculating reconstruction metrics...")
+     recon_metrics = calculate_reconstruction_metrics(model, data_loader, device)
+
+     print(f"   🧠 Calculating latent space metrics...")
+     latent_metrics = calculate_latent_classification_metrics(model, data_loader, device)
+
+     # Combine all metrics
+     comprehensive_metrics = {
+         **recon_metrics,
+         **latent_metrics
+     }
+
+     return comprehensive_metrics
+
+
+ def print_metrics_summary(metrics, subversion_name):
+     """Print a formatted summary of metrics"""
+     print(f"\n📊 Metrics Summary for {subversion_name}:")
+     print("=" * 60)
+
+     # Reconstruction metrics
+     print(f"🔄 Reconstruction Loss: {metrics['loss']:.6f}")
+     print(f"📏 Reconstruction Error: {metrics['reconstruction_error']:.6f} ± {metrics['reconstruction_std']:.6f}")
+     print(f"📊 Reconstruction Median: {metrics['reconstruction_median']:.6f}")
+     print(f"📈 25th/75th Percentile: {metrics['reconstruction_25th']:.6f} / {metrics['reconstruction_75th']:.6f}")
+     print(f"✅ Good Reconstructions: {metrics['good_reconstructions']}/{metrics['num_samples']} ({100*metrics['good_reconstructions']/metrics['num_samples']:.1f}%)")
+     print(f"❌ Poor Reconstructions: {metrics['poor_reconstructions']}/{metrics['num_samples']} ({100*metrics['poor_reconstructions']/metrics['num_samples']:.1f}%)")
+     print(f"🎯 Better than Median: {metrics['better_than_median_count']}/{metrics['num_samples']} ({100*metrics['better_than_median_count']/metrics['num_samples']:.1f}%)")
+     print(f"⭐ High Quality (top 25%): {metrics['high_quality_count']}/{metrics['num_samples']} ({100*metrics['high_quality_count']/metrics['num_samples']:.1f}%)")
+
+     # Classification metrics with explanation
+     print(f"\n🎯 Reconstruction Classification Metrics:")
+     print(f"   Accuracy: {metrics['accuracy']:.4f}")
+     print(f"   Precision: {metrics['precision']:.4f}")
+     print(f"   Recall: {metrics['recall']:.4f}")
+     print(f"   F1-Score: {metrics['f1_score']:.4f}")
+     print(f"   ℹ️ {metrics['classification_note']}")
+
+     # Latent space metrics
+     print(f"\n🧠 Latent Space Analysis:")
+     print(f"   Latent Accuracy: {metrics['latent_accuracy']:.4f}")
+     print(f"   Latent Precision: {metrics['latent_precision']:.4f}")
+     print(f"   Latent Recall: {metrics['latent_recall']:.4f}")
+     print(f"   Latent F1-Score: {metrics['latent_f1_score']:.4f}")
+     print(f"   Silhouette Score: {metrics['silhouette_score']:.4f}")
+     print(f"   Clusters Found: {metrics['n_clusters']}")
+     print(f"   Latent Dimension: {metrics['latent_dim']}")
+     print(f"   Cluster Quality: {metrics['cluster_quality']}")
+
+     # Interpretation guide
+     print(f"\n📖 Interpretation Guide:")
+     print(f"   • Reconstruction Loss: Lower = better image reconstruction")
+     print(f"   • Latent Accuracy: How well clustering separates kidney stone classes")
+     print(f"   • Silhouette Score: Quality of latent space clustering (higher = better)")
+     print(f"   • Perfect precision (1.0) in reconstruction metrics may indicate")
+     print(f"     threshold alignment rather than exceptional model performance")
+
+     print(f"\n📦 Total Samples: {metrics['num_samples']}")
+     print("=" * 60)
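A closing sketch tying the pieces together for offline evaluation, under the same config/models assumptions as above and using a checkpoint from this upload. Note the metrics above expect the model's forward pass to return (reconstruction, latent).

from simple_anomaly_detector import SimpleAnomalyDetector
from utils.data_utils import create_global_test_loader
from utils.metrics import calculate_comprehensive_metrics, print_metrics_summary

detector = SimpleAnomalyDetector("models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth")
test_loader = create_global_test_loader(subversions=["MIX"])

metrics = calculate_comprehensive_metrics(detector.model, test_loader, detector.device)
print_metrics_summary(metrics, "Daudon_MIX")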