Update utils/data_utils.py

utils/data_utils.py (+289 -0)

@@ -105,6 +105,200 @@ class KidneyStoneDataset(Dataset):
         return image, label
 
 
+def load_dataset_paths(datasets=None, subversions=None):
+    """Load image paths and labels from specified datasets and subversions"""
+    all_paths = []
+    all_labels = []
+
+    # Use all datasets if none specified
+    if datasets is None:
+        datasets = config.DATASETS
+
+    # Use all subversions if none specified
+    if subversions is None:
+        subversions = config.SUBVERSIONS
+
+    for dataset_name in datasets:
+        dataset_path = os.path.join(config.DATA_ROOT, dataset_name)
+
+        if not os.path.exists(dataset_path):
+            print(f"Warning: Dataset path does not exist: {dataset_path}")
+            continue
+
+        for subversion in subversions:
+            subversion_path = os.path.join(dataset_path, subversion)
+
+            if not os.path.exists(subversion_path):
+                print(f"Warning: Subversion path does not exist: {subversion_path}")
+                continue
+
+            # Load training images (extract class from folder structure)
+            train_path = os.path.join(subversion_path, "train")
+            if os.path.exists(train_path):
+                # Get all class folders in train directory
+                class_folders = [d for d in os.listdir(train_path)
+                                 if os.path.isdir(os.path.join(train_path, d))]
+
+                for class_folder in class_folders:
+                    class_path = os.path.join(train_path, class_folder)
+
+                    # Load all images in this class folder
+                    for img_file in os.listdir(class_path):
+                        if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
+                            img_path = os.path.join(class_path, img_file)
+                            all_paths.append(img_path)
+                            # Create label with class information: "subversion_class"
+                            all_labels.append(f"{subversion}_{class_folder}")
+
+            # Load test images (extract class from folder structure)
+            test_path = os.path.join(subversion_path, "test")
+            if os.path.exists(test_path):
+                # Get all class folders in test directory
+                class_folders = [d for d in os.listdir(test_path)
+                                 if os.path.isdir(os.path.join(test_path, d))]
+
+                for class_folder in class_folders:
+                    class_path = os.path.join(test_path, class_folder)
+
+                    # Load all images in this class folder
+                    for img_file in os.listdir(class_path):
+                        if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
+                            img_path = os.path.join(class_path, img_file)
+                            all_paths.append(img_path)
+                            # Create label with class information: "subversion_class"
+                            all_labels.append(f"{subversion}_{class_folder}")
+
+    print(f"📊 Data loading summary:")
+    print(f"   Total images: {len(all_paths)}")
+    print(f"   Unique classes found: {len(set(all_labels))}")
+    print(f"   Classes: {sorted(set(all_labels))}")
+
+    return all_paths, all_labels
+
+
+def redistribute_data_evenly(image_paths, labels, num_clients):
+    """Redistribute data evenly among clients as fallback"""
+    total_samples = len(image_paths)
+    samples_per_client = total_samples // num_clients
+
+    # Shuffle data
+    combined = list(zip(image_paths, labels))
+    np.random.shuffle(combined)
+
+    client_datasets = []
+    for i in range(num_clients):
+        start_idx = i * samples_per_client
+        if i == num_clients - 1:  # Last client gets remaining samples
+            end_idx = total_samples
+        else:
+            end_idx = (i + 1) * samples_per_client
+
+        client_data = combined[start_idx:end_idx]
+        if client_data:
+            client_paths, client_labels = zip(*client_data)
+            client_datasets.append((list(client_paths), list(client_labels)))
+            print(f"Client {i} redistributed with {len(client_paths)} samples")
+
+    return client_datasets
+
+
+def create_non_iid_distribution(image_paths, labels, num_clients, alpha=0.5):
+    """Create non-IID data distribution using Dirichlet distribution"""
+
+    # Convert labels to numeric
+    unique_labels = list(set(labels))
+    label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
+    numeric_labels = [label_to_idx[label] for label in labels]
+
+    num_classes = len(unique_labels)
+
+    # Create Dirichlet distribution for each client
+    client_distributions = np.random.dirichlet([alpha] * num_classes, num_clients)
+
+    # Group data by class
+    class_indices = {i: [] for i in range(num_classes)}
+    for idx, label in enumerate(numeric_labels):
+        class_indices[label].append(idx)
+
+    # Distribute data to clients
+    client_data = [[] for _ in range(num_clients)]
+
+    for class_idx in range(num_classes):
+        class_data = class_indices[class_idx]
+        np.random.shuffle(class_data)
+
+        # Calculate how many samples each client gets from this class
+        total_samples = len(class_data)
+        client_samples = (client_distributions[:, class_idx] * total_samples).astype(int)
+
+        # Ensure we don't exceed total samples
+        if client_samples.sum() > total_samples:
+            excess = client_samples.sum() - total_samples
+            client_samples[-1] -= excess
+
+        # Distribute samples
+        start_idx = 0
+        for client_idx, num_samples in enumerate(client_samples):
+            if num_samples > 0:
+                end_idx = start_idx + num_samples
+                client_data[client_idx].extend(class_data[start_idx:end_idx])
+                start_idx = end_idx
+
+    # Convert indices back to paths and labels
+    client_datasets = []
+    for client_idx, client_indices in enumerate(client_data):
+        if len(client_indices) > 0:  # Accept any client with at least some data
+            client_paths = [image_paths[i] for i in client_indices]
+            client_labels = [labels[i] for i in client_indices]
+            client_datasets.append((client_paths, client_labels))
+            print(f"Client {client_idx} will have {len(client_indices)} samples")
+        else:
+            print(f"Warning: Client {client_idx} has no samples assigned")
+
+    # If we don't have enough clients, redistribute the data more evenly
+    if len(client_datasets) < num_clients:
+        print(f"Warning: Only {len(client_datasets)} clients have sufficient data. Redistributing...")
+        return redistribute_data_evenly(image_paths, labels, num_clients)
+
+    return client_datasets
+
+
+def safe_train_test_split(paths, labels, test_size=0.2, random_state=None):
+    """
+    Safely split data into train/test, handling classes with insufficient samples
+    """
+    # Count samples per class
+    class_counts = Counter(labels)
+
+    # Check if we can do stratified split
+    min_class_size = min(class_counts.values())
+    can_stratify = min_class_size >= 2
+
+    if can_stratify:
+        try:
+            return train_test_split(
+                paths, labels,
+                test_size=test_size,
+                random_state=random_state,
+                stratify=labels
+            )
+        except ValueError as e:
+            print(f"   ⚠️ Stratified split failed: {e}")
+            can_stratify = False
+
+    if not can_stratify:
+        print(f"   📊 Using random split (some classes have <2 samples)")
+        print(f"   📈 Class distribution: {dict(class_counts)}")
+
+        # Use random split without stratification
+        return train_test_split(
+            paths, labels,
+            test_size=test_size,
+            random_state=random_state,
+            stratify=None
+        )
+
+
 def get_data_transforms():
     """Get data transformations for training and testing"""
 
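Reconstructed from the `os.path.join` calls in `load_dataset_paths`, the on-disk layout this hunk expects is roughly the following; the dataset, subversion, and class names are placeholders, and only `.png`/`.jpg`/`.jpeg` files are collected:

    config.DATA_ROOT/
        <dataset_name>/
            <subversion>/
                train/
                    <class_folder>/   (images; label becomes "<subversion>_<class_folder>")
                test/
                    <class_folder>/   (same labeling as train)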
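To see what the `alpha` parameter in `create_non_iid_distribution` controls, here is a minimal, self-contained sketch of the same Dirichlet draw on synthetic labels. The client count, class count, and `alpha` values are illustrative assumptions, not values from this repo:

    import numpy as np

    # Sketch only: same Dirichlet draw and flooring as create_non_iid_distribution
    np.random.seed(0)
    num_clients, num_classes, per_class = 4, 5, 100  # illustrative sizes

    for alpha in (0.1, 0.5, 10.0):
        # One row per client: a distribution over the classes
        dists = np.random.dirichlet([alpha] * num_classes, num_clients)
        # floor(p * class_size) samples of each class go to each client
        counts = (dists * per_class).astype(int)
        print(f"alpha={alpha}:\n{counts}\n")

Small `alpha` concentrates each client on one or two classes (strongly non-IID); large `alpha` approaches an even split. The flooring can drop samples and leave a client empty after a skewed draw, which is exactly the case the `redistribute_data_evenly` fallback in this hunk covers.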
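And the failure mode `safe_train_test_split` guards against: scikit-learn's `train_test_split` raises a `ValueError` when asked to stratify over a class with fewer than two samples, which skewed client shards can easily produce. A minimal reproduction with placeholder paths and labels:

    from sklearn.model_selection import train_test_split

    paths = ["img0.png", "img1.png", "img2.png"]
    labels = ["svA_classX", "svA_classX", "svA_classY"]  # "svA_classY" has one sample

    try:
        train_test_split(paths, labels, test_size=0.2, stratify=labels)
    except ValueError as e:
        print(f"Stratified split failed: {e}")

    # The fallback the new helper takes: a plain random split
    train_p, test_p, train_l, test_l = train_test_split(
        paths, labels, test_size=0.2, random_state=42, stratify=None
    )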
@@ -125,3 +319,98 @@ def get_data_transforms():
 
     return train_transform, test_transform
 
+
+def create_client_dataloaders(num_clients, corruption_prob=0.1, alpha=0.5, datasets=None, subversions=None):
+    """Create data loaders for all clients with non-IID distribution"""
+
+    # Load data from specified datasets and subversions
+    all_paths, all_labels = load_dataset_paths(datasets=datasets, subversions=subversions)
+
+    print(f"Total images loaded: {len(all_paths)}")
+    print(f"Unique labels: {set(all_labels)}")
+
+    if len(all_paths) == 0:
+        raise ValueError("No images found! Please check your dataset paths and subversions.")
+
+    # Create non-IID distribution
+    client_datasets = create_non_iid_distribution(all_paths, all_labels, num_clients, alpha)
+
+    print(f"Created {len(client_datasets)} client datasets")
+
+    # Get transforms
+    train_transform, test_transform = get_data_transforms()
+
+    # Create data loaders for each client
+    client_loaders = []
+
+    for i, (client_paths, client_labels) in enumerate(client_datasets):
+        print(f"Client {i}: {len(client_paths)} samples")
+
+        # Split into train/test for each client using safe splitting
+        train_paths, test_paths, train_labels, test_labels = safe_train_test_split(
+            client_paths, client_labels, test_size=0.2, random_state=config.SEED
+        )
+
+        # Create datasets
+        train_dataset = KidneyStoneDataset(
+            train_paths, train_labels,
+            transform=train_transform,
+            corruption_prob=corruption_prob
+        )
+
+        test_dataset = KidneyStoneDataset(
+            test_paths, test_labels,
+            transform=test_transform,
+            corruption_prob=0.0  # No corruption for test data
+        )
+
+        # Create data loaders
+        train_loader = DataLoader(
+            train_dataset,
+            batch_size=config.BATCH_SIZE,
+            shuffle=True,
+            num_workers=2
+        )
+
+        test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.BATCH_SIZE,
+            shuffle=False,
+            num_workers=2
+        )
+
+        client_loaders.append((train_loader, test_loader))
+
+    return client_loaders
+
+
+def create_global_test_loader(datasets=None, subversions=None):
+    """Create a global test loader for evaluation"""
+
+    # Load data from specified datasets and subversions
+    all_paths, all_labels = load_dataset_paths(datasets=datasets, subversions=subversions)
+
+    if len(all_paths) == 0:
+        raise ValueError("No images found for global test loader! Please check your dataset paths and subversions.")
+
+    # Use a subset for global testing with safe splitting
+    _, test_paths, _, test_labels = safe_train_test_split(
+        all_paths, all_labels, test_size=0.1, random_state=config.SEED
+    )
+
+    _, test_transform = get_data_transforms()
+
+    test_dataset = KidneyStoneDataset(
+        test_paths, test_labels,
+        transform=test_transform,
+        corruption_prob=0.0
+    )
+
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=config.BATCH_SIZE,
+        shuffle=False,
+        num_workers=2
+    )
+
+    return test_loader
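Finally, a hedged sketch of how a training script might call the new entry points. It is not part of the diff, the argument values are illustrative, and it assumes the `config` module exposes the attributes this patch already references (`DATA_ROOT`, `DATASETS`, `SUBVERSIONS`, `BATCH_SIZE`, `SEED`):

    from utils.data_utils import create_client_dataloaders, create_global_test_loader

    # One (train_loader, test_loader) pair per client, partitioned non-IID
    client_loaders = create_client_dataloaders(
        num_clients=5,        # illustrative
        corruption_prob=0.1,  # forwarded to KidneyStoneDataset for train splits only
        alpha=0.5,            # Dirichlet concentration; smaller = more skewed
    )

    # Shared held-out loader (10% of all images) for evaluating the global model
    global_test_loader = create_global_test_loader()

    for client_id, (train_loader, test_loader) in enumerate(client_loaders):
        images, labels = next(iter(train_loader))
        print(f"client {client_id}: first batch of shape {tuple(images.shape)}")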