Ivanrs committed
Commit 7a224e8 · verified · 1 Parent(s): 6cd7429

Update utils/data_utils.py

Files changed (1): utils/data_utils.py +289 -0
utils/data_utils.py CHANGED
@@ -105,6 +105,200 @@ class KidneyStoneDataset(Dataset):
         return image, label
 
 
+def load_dataset_paths(datasets=None, subversions=None):
+    """Load image paths and labels from specified datasets and subversions"""
+    all_paths = []
+    all_labels = []
+
+    # Use all datasets if none specified
+    if datasets is None:
+        datasets = config.DATASETS
+
+    # Use all subversions if none specified
+    if subversions is None:
+        subversions = config.SUBVERSIONS
+
+    for dataset_name in datasets:
+        dataset_path = os.path.join(config.DATA_ROOT, dataset_name)
+
+        if not os.path.exists(dataset_path):
+            print(f"Warning: Dataset path does not exist: {dataset_path}")
+            continue
+
+        for subversion in subversions:
+            subversion_path = os.path.join(dataset_path, subversion)
+
+            if not os.path.exists(subversion_path):
+                print(f"Warning: Subversion path does not exist: {subversion_path}")
+                continue
+
+            # Load training images (extract class from folder structure)
+            train_path = os.path.join(subversion_path, "train")
+            if os.path.exists(train_path):
+                # Get all class folders in train directory
+                class_folders = [d for d in os.listdir(train_path)
+                                 if os.path.isdir(os.path.join(train_path, d))]
+
+                for class_folder in class_folders:
+                    class_path = os.path.join(train_path, class_folder)
+
+                    # Load all images in this class folder
+                    for img_file in os.listdir(class_path):
+                        if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
+                            img_path = os.path.join(class_path, img_file)
+                            all_paths.append(img_path)
+                            # Create label with class information: "subversion_class"
+                            all_labels.append(f"{subversion}_{class_folder}")
+
+            # Load test images (extract class from folder structure)
+            test_path = os.path.join(subversion_path, "test")
+            if os.path.exists(test_path):
+                # Get all class folders in test directory
+                class_folders = [d for d in os.listdir(test_path)
+                                 if os.path.isdir(os.path.join(test_path, d))]
+
+                for class_folder in class_folders:
+                    class_path = os.path.join(test_path, class_folder)
+
+                    # Load all images in this class folder
+                    for img_file in os.listdir(class_path):
+                        if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
+                            img_path = os.path.join(class_path, img_file)
+                            all_paths.append(img_path)
+                            # Create label with class information: "subversion_class"
+                            all_labels.append(f"{subversion}_{class_folder}")
+
+    print("📊 Data loading summary:")
+    print(f"   Total images: {len(all_paths)}")
+    print(f"   Unique classes found: {len(set(all_labels))}")
+    print(f"   Classes: {sorted(set(all_labels))}")
+
+    return all_paths, all_labels
+
+
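
For orientation, a minimal sketch of the directory layout this function implies and how it might be called. The dataset, subversion, and class names are hypothetical; only the train/test/class-folder nesting follows from the code above.

    # Assumed layout (names illustrative):
    #   <config.DATA_ROOT>/<dataset>/<subversion>/train/<class>/*.jpg
    #   <config.DATA_ROOT>/<dataset>/<subversion>/test/<class>/*.jpg
    paths, labels = load_dataset_paths(datasets=["DatasetA"], subversions=["MIX"])
    # Each label combines subversion and class folder, e.g. "MIX_uric_acid"
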
+def redistribute_data_evenly(image_paths, labels, num_clients):
+    """Redistribute data evenly among clients as fallback"""
+    total_samples = len(image_paths)
+    samples_per_client = total_samples // num_clients
+
+    # Shuffle data
+    combined = list(zip(image_paths, labels))
+    np.random.shuffle(combined)
+
+    client_datasets = []
+    for i in range(num_clients):
+        start_idx = i * samples_per_client
+        if i == num_clients - 1:  # Last client gets remaining samples
+            end_idx = total_samples
+        else:
+            end_idx = (i + 1) * samples_per_client
+
+        client_data = combined[start_idx:end_idx]
+        if client_data:
+            client_paths, client_labels = zip(*client_data)
+            client_datasets.append((list(client_paths), list(client_labels)))
+            print(f"Client {i} redistributed with {len(client_paths)} samples")
+
+    return client_datasets
+
+
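
As a quick arithmetic check of this fallback (hypothetical numbers): with 103 samples and 4 clients, `samples_per_client = 103 // 4 = 25`, so clients 0 through 2 receive 25 samples each and the last client receives the remaining 28.
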
+def create_non_iid_distribution(image_paths, labels, num_clients, alpha=0.5):
+    """Create non-IID data distribution using Dirichlet distribution"""
+
+    # Convert labels to numeric
+    unique_labels = list(set(labels))
+    label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
+    numeric_labels = [label_to_idx[label] for label in labels]
+
+    num_classes = len(unique_labels)
+
+    # Create Dirichlet distribution for each client
+    client_distributions = np.random.dirichlet([alpha] * num_classes, num_clients)
+
+    # Group data by class
+    class_indices = {i: [] for i in range(num_classes)}
+    for idx, label in enumerate(numeric_labels):
+        class_indices[label].append(idx)
+
+    # Distribute data to clients
+    client_data = [[] for _ in range(num_clients)]
+
+    for class_idx in range(num_classes):
+        class_data = class_indices[class_idx]
+        np.random.shuffle(class_data)
+
+        # Calculate how many samples each client gets from this class
+        total_samples = len(class_data)
+        client_samples = (client_distributions[:, class_idx] * total_samples).astype(int)
+
+        # Ensure we don't exceed total samples
+        if client_samples.sum() > total_samples:
+            excess = client_samples.sum() - total_samples
+            client_samples[-1] -= excess
+
+        # Distribute samples
+        start_idx = 0
+        for client_idx, num_samples in enumerate(client_samples):
+            if num_samples > 0:
+                end_idx = start_idx + num_samples
+                client_data[client_idx].extend(class_data[start_idx:end_idx])
+                start_idx = end_idx
+
+    # Convert indices back to paths and labels
+    client_datasets = []
+    for client_idx, client_indices in enumerate(client_data):
+        if len(client_indices) > 0:  # Accept any client with at least some data
+            client_paths = [image_paths[i] for i in client_indices]
+            client_labels = [labels[i] for i in client_indices]
+            client_datasets.append((client_paths, client_labels))
+            print(f"Client {client_idx} will have {len(client_indices)} samples")
+        else:
+            print(f"Warning: Client {client_idx} has no samples assigned")
+
+    # If we don't have enough clients, redistribute the data more evenly
+    if len(client_datasets) < num_clients:
+        print(f"Warning: Only {len(client_datasets)} clients have sufficient data. Redistributing...")
+        return redistribute_data_evenly(image_paths, labels, num_clients)
+
+    return client_datasets
+
+
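
To make the role of `alpha` concrete, here is a small self-contained sketch (independent of this repo) drawing the same per-client class proportions the function samples:

    import numpy as np

    np.random.seed(0)
    num_clients, num_classes = 3, 4
    for alpha in (0.1, 10.0):
        # One row per client: its class-proportion vector over the 4 classes
        dist = np.random.dirichlet([alpha] * num_classes, num_clients)
        print(alpha, np.round(dist, 2))
    # Small alpha -> each client concentrates on a few classes (strongly non-IID);
    # large alpha -> rows approach uniform (nearly IID).
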
+def safe_train_test_split(paths, labels, test_size=0.2, random_state=None):
+    """
+    Safely split data into train/test, handling classes with insufficient samples
+    """
+    # Count samples per class
+    class_counts = Counter(labels)
+
+    # Check if we can do stratified split
+    min_class_size = min(class_counts.values())
+    can_stratify = min_class_size >= 2
+
+    if can_stratify:
+        try:
+            return train_test_split(
+                paths, labels,
+                test_size=test_size,
+                random_state=random_state,
+                stratify=labels
+            )
+        except ValueError as e:
+            print(f"   ⚠️ Stratified split failed: {e}")
+            can_stratify = False
+
+    if not can_stratify:
+        print("   📊 Using random split (some classes have <2 samples)")
+        print(f"   📈 Class distribution: {dict(class_counts)}")
+
+    # Use random split without stratification
+    return train_test_split(
+        paths, labels,
+        test_size=test_size,
+        random_state=random_state,
+        stratify=None
+    )
+
+
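
A quick illustration of the fallback path with hypothetical data (this assumes `collections.Counter` and scikit-learn's `train_test_split` are imported at the top of the module, which this hunk does not show):

    paths = [f"img_{i}.png" for i in range(10)]
    labels = ["MIX_classA"] * 9 + ["MIX_classB"]   # "MIX_classB" has one sample
    tr_p, te_p, tr_l, te_l = safe_train_test_split(paths, labels, random_state=42)
    # min class size is 1, so stratification is skipped and a plain
    # random split is used instead of raising ValueError
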
 def get_data_transforms():
     """Get data transformations for training and testing"""
 
@@ -125,3 +319,98 @@ def get_data_transforms():
 
     return train_transform, test_transform
 
+
+def create_client_dataloaders(num_clients, corruption_prob=0.1, alpha=0.5, datasets=None, subversions=None):
+    """Create data loaders for all clients with non-IID distribution"""
+
+    # Load data from specified datasets and subversions
+    all_paths, all_labels = load_dataset_paths(datasets=datasets, subversions=subversions)
+
+    print(f"Total images loaded: {len(all_paths)}")
+    print(f"Unique labels: {set(all_labels)}")
+
+    if len(all_paths) == 0:
+        raise ValueError("No images found! Please check your dataset paths and subversions.")
+
+    # Create non-IID distribution
+    client_datasets = create_non_iid_distribution(all_paths, all_labels, num_clients, alpha)
+
+    print(f"Created {len(client_datasets)} client datasets")
+
+    # Get transforms
+    train_transform, test_transform = get_data_transforms()
+
+    # Create data loaders for each client
+    client_loaders = []
+
+    for i, (client_paths, client_labels) in enumerate(client_datasets):
+        print(f"Client {i}: {len(client_paths)} samples")
+
+        # Split into train/test for each client using safe splitting
+        train_paths, test_paths, train_labels, test_labels = safe_train_test_split(
+            client_paths, client_labels, test_size=0.2, random_state=config.SEED
+        )
+
+        # Create datasets
+        train_dataset = KidneyStoneDataset(
+            train_paths, train_labels,
+            transform=train_transform,
+            corruption_prob=corruption_prob
+        )
+
+        test_dataset = KidneyStoneDataset(
+            test_paths, test_labels,
+            transform=test_transform,
+            corruption_prob=0.0  # No corruption for test data
+        )
+
+        # Create data loaders
+        train_loader = DataLoader(
+            train_dataset,
+            batch_size=config.BATCH_SIZE,
+            shuffle=True,
+            num_workers=2
+        )
+
+        test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.BATCH_SIZE,
+            shuffle=False,
+            num_workers=2
+        )
+
+        client_loaders.append((train_loader, test_loader))
+
+    return client_loaders
+
+
+ def create_global_test_loader(datasets=None, subversions=None):
388
+ """Create a global test loader for evaluation"""
389
+
390
+ # Load data from specified datasets and subversions
391
+ all_paths, all_labels = load_dataset_paths(datasets=datasets, subversions=subversions)
392
+
393
+ if len(all_paths) == 0:
394
+ raise ValueError("No images found for global test loader! Please check your dataset paths and subversions.")
395
+
396
+ # Use a subset for global testing with safe splitting
397
+ _, test_paths, _, test_labels = safe_train_test_split(
398
+ all_paths, all_labels, test_size=0.1, random_state=config.SEED
399
+ )
400
+
401
+ _, test_transform = get_data_transforms()
402
+
403
+ test_dataset = KidneyStoneDataset(
404
+ test_paths, test_labels,
405
+ transform=test_transform,
406
+ corruption_prob=0.0
407
+ )
408
+
409
+ test_loader = DataLoader(
410
+ test_dataset,
411
+ batch_size=config.BATCH_SIZE,
412
+ shuffle=False,
413
+ num_workers=2
414
+ )
415
+
416
+ return test_loader
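
Taken together, a hypothetical driver showing how these helpers compose. The client count and hyperparameters below are illustrative, and the module is assumed to import `os`, `numpy as np`, `Counter`, `train_test_split`, `DataLoader`, and `config` at its top, which this diff does not show:

    from utils.data_utils import create_client_dataloaders, create_global_test_loader

    # Partition the image pool across 5 federated clients, mildly non-IID
    client_loaders = create_client_dataloaders(num_clients=5, corruption_prob=0.1, alpha=0.5)
    global_test_loader = create_global_test_loader()

    for i, (train_loader, test_loader) in enumerate(client_loaders):
        print(f"Client {i}: {len(train_loader.dataset)} train / {len(test_loader.dataset)} test samples")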