# dataloader.py
import os
from functools import partial

import pandas as pd
import tensorflow as tf


def validate_data_files(data_df):
    """Validate that all image and target files exist before starting training."""
    missing_files = []
    for _, row in data_df.iterrows():
        if not os.path.exists(row['filename']):
            missing_files.append(row['filename'])
        if not os.path.exists(row['target']):
            missing_files.append(row['target'])
    if missing_files:
        print("Missing files:")
        for f in missing_files[:10]:
            print(f"  {f}")
        if len(missing_files) > 10:
            print(f"  ... and {len(missing_files) - 10} more")
        raise FileNotFoundError(f"Found {len(missing_files)} missing files")


def _get_pupil_position(pmap, datum, x_shape):
    """Return the (y, x) pupil position: the mask centroid when the mask is
    non-empty, otherwise the ROI centre from the annotation, otherwise the
    image centre."""
    total_mass = tf.reduce_sum(pmap)
    if total_mass > 0:
        shape = tf.shape(pmap)
        h, w = shape[0], shape[1]
        ii, jj = tf.meshgrid(tf.range(h), tf.range(w), indexing='ij')
        y = tf.reduce_sum(tf.cast(ii, 'float32') * pmap) / total_mass
        x = tf.reduce_sum(tf.cast(jj, 'float32') * pmap) / total_mass
        return tf.stack((y, x))
    if 'roi_x' in datum and 'roi_y' in datum and 'roi_w' in datum:
        roi_x = tf.cast(datum['roi_x'], 'float32')
        roi_y = tf.cast(datum['roi_y'], 'float32')
        half = tf.cast(datum['roi_w'] / 2, 'float32')
        result = tf.stack((roi_y + half, roi_x + half))
    else:
        # Fall back to the centre of the image.
        result = tf.cast(tf.stack((x_shape[0] / 2, x_shape[1] / 2)), dtype='float32')
    return result


def additional_augmentations(image, mask, p=0.3):
    """Photometric augmentations applied to the image only; the mask is returned unchanged."""
    # Keep the original image with probability 1 - p.
    if tf.random.uniform([]) > p:
        return image, mask

    # Additive Gaussian noise.
    if tf.random.uniform([]) < 0.3:
        noise = tf.random.normal(tf.shape(image), mean=0.0, stddev=0.05)
        image = tf.clip_by_value(image + noise, 0., 1.)

    # Gaussian blur with a random sigma (kept strictly positive so the kernel
    # normalisation never divides by zero).
    if tf.random.uniform([]) < 0.3:
        kernel_size = 3
        sigma = tf.random.uniform([], 0.1, 1.0)
        x = tf.range(-(kernel_size // 2), kernel_size // 2 + 1, dtype=tf.float32)
        gaussian = tf.exp(-(x ** 2) / (2. * sigma ** 2))
        gaussian = gaussian / tf.reduce_sum(gaussian)
        gaussian = tf.reshape(gaussian, [kernel_size, 1])
        gaussian_kernel = gaussian @ tf.transpose(gaussian)
        gaussian_kernel = tf.reshape(gaussian_kernel, [kernel_size, kernel_size, 1, 1])
        image = tf.nn.conv2d(tf.expand_dims(image, 0), gaussian_kernel,
                             strides=[1, 1, 1, 1], padding='SAME')[0]

    # Sharpening with a fixed 3x3 kernel (weights sum to 1).
    if tf.random.uniform([]) < 0.3:
        kernel = tf.constant([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], dtype=tf.float32)
        kernel = tf.reshape(kernel, [3, 3, 1, 1])
        image = tf.nn.conv2d(tf.expand_dims(image, 0), kernel,
                             strides=[1, 1, 1, 1], padding='SAME')[0]
        image = tf.clip_by_value(image, 0., 1.)

    return image, mask


@tf.function
def load_datum(datum, x_shape=(128, 128, 1), augment=False):
    """Load one annotated frame and return (image, mask, [eye, blink])."""
    try:
        x = tf.io.read_file(datum['filename'])
        y = tf.io.read_file(datum['target'])
        # Decode to HWC float32 in [0, 1].
        x = tf.io.decode_image(x, channels=1, dtype='float32', expand_animations=False)
        y = tf.io.decode_image(y, dtype='float32', expand_animations=False)

        # Image dimensions.
        h = tf.shape(x)[0]
        w = tf.shape(x)[1]

        # Pupil information from the first mask channel.
        pupil_map = y[:, :, 0]
        pupil_area = tf.reduce_sum(pupil_map)
        pupil_pos_yx = _get_pupil_position(pupil_map, datum, x_shape)  # not used by the crop below

        # Side of the square crop.
        target_size = tf.minimum(tf.minimum(h, w), x_shape[0])
        if not augment:
            # Centre crop.
            h_start = (h - target_size) // 2
            w_start = (w - target_size) // 2
        else:
            # Random crop within valid bounds.
            h_start = tf.random.uniform([], 0, h - target_size + 1, dtype=tf.int32)
            w_start = tf.random.uniform([], 0, w - target_size + 1, dtype=tf.int32)

        # Apply the same crop to the image and the mask.
        x = tf.image.crop_to_bounding_box(x, h_start, w_start, target_size, target_size)
        y = tf.image.crop_to_bounding_box(y, h_start, w_start, target_size, target_size)

        if augment:
            # Random rotation by a multiple of 90 degrees.
            k = tf.random.uniform([], 0, 4, dtype=tf.int32)
            x = tf.image.rot90(x, k=k)
            y = tf.image.rot90(y, k=k)

            # Random horizontal / vertical flips.
            if tf.random.uniform([]) < 0.5:
                x = tf.image.flip_left_right(x)
                y = tf.image.flip_left_right(y)
            if tf.random.uniform([]) < 0.5:
                x = tf.image.flip_up_down(x)
                y = tf.image.flip_up_down(y)

            # Photometric augmentations (noise, blur, sharpening).
            x, y = additional_augmentations(x, y)

        # Fraction of the annotated pupil that survived the crop
        # (0 when no pupil is annotated).
        new_pupil_map = y[:, :, 0]
        new_pupil_area = tf.reduce_sum(new_pupil_map)
        eye = tf.math.divide_no_nan(new_pupil_area, pupil_area)

        # Eye / blink tags: a closed eye cannot be blinking, and an open,
        # non-blinking eye gets the visible pupil fraction as a soft label.
        # tf.where keeps these updates graph-friendly.
        datum_eye = tf.cast(datum['eye'], 'float32')
        datum_blink = tf.cast(datum['blink'], 'float32')
        datum_blink = tf.where(datum_eye == 0., 0., datum_blink)
        datum_eye = tf.where((datum_eye == 1.) & (datum_blink == 0.), eye, datum_eye)

        # Resize to the requested input size if needed.
        if target_size != x_shape[0]:
            x = tf.image.resize(x, [x_shape[0], x_shape[0]])
            y = tf.image.resize(y, [x_shape[0], x_shape[0]])
        y = y[:, :, :1]

        y2 = tf.stack((datum_eye, datum_blink))
        return x, y, y2
    except Exception as e:
        print(f"Error processing datum: {e}")
        raise


def get_loader(dataframe, batch_size=8, shuffle=False, **kwargs):
    """Build a batched tf.data pipeline; extra kwargs are forwarded to load_datum."""
    categories = dataframe.exp.values
    dataset = tf.data.Dataset.from_tensor_slices(dict(dataframe))
    if shuffle:
        dataset = dataset.shuffle(1000)

    dataset = dataset.map(
        partial(load_datum, **kwargs),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=not shuffle,
    )
    dataset = dataset.batch(batch_size)

    def _pack_targets(*ins):
        # Keras expects an (inputs, targets) tuple; targets are named outputs.
        inputs = ins[0]
        targets = {'mask': ins[1], 'tags': ins[2]}
        return inputs, targets

    dataset = dataset.map(
        _pack_targets,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=not shuffle,
    )
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset, categories


def load_datasets(dataset_dirs):
    """Load and merge the annotation dataframes of one or more dataset directories."""

    def _load_and_prepare_annotations(dataset_dir):
        dataset_dir = os.path.normpath(dataset_dir)
        data_path = os.path.join(dataset_dir, 'annotation', 'annotations.csv')
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Annotations file not found: {data_path}")
        data = pd.read_csv(data_path)

        # Expected layout: target masks under annotation/png, frames under fullFrames.
        png_dir = os.path.join(dataset_dir, 'annotation', 'png')
        full_frames_dir = os.path.join(dataset_dir, 'fullFrames')
        os.makedirs(png_dir, exist_ok=True)
        os.makedirs(full_frames_dir, exist_ok=True)

        # Keep only rows whose frame exists on disk; .copy() avoids pandas'
        # SettingWithCopyWarning on the assignments below.
        valid_data = data[data.apply(
            lambda row: os.path.exists(
                os.path.join(full_frames_dir, os.path.basename(row['filename']))),
            axis=1)].copy()
        if len(valid_data) == 0:
            raise ValueError(f"No valid image files found in {full_frames_dir}")

        # Build absolute paths for target masks and input frames.
        valid_data['target'] = valid_data['filename'].apply(
            lambda f: os.path.join(png_dir, os.path.splitext(os.path.basename(f))[0] + '.png')
        )
        valid_data['filename'] = valid_data['filename'].apply(
            lambda f: os.path.join(full_frames_dir, os.path.basename(f))
        )
        return valid_data

    all_data = []
    for d in dataset_dirs:
        try:
            all_data.append(_load_and_prepare_annotations(d))
        except Exception as e:
            print(f"Error loading dataset from {d}: {e}")
            continue

    if not all_data:
        raise ValueError("No valid datasets found in any of the provided directories")

    dataset = pd.concat(all_data, ignore_index=True)
    dataset['sub'] = dataset['sub'].astype(str)
    print(f"Found {len(dataset)} valid image pairs")
    return dataset
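

if __name__ == '__main__':
    # Minimal usage sketch. The directory below is a placeholder, and the
    # annotations CSV is assumed to provide the columns this module reads
    # ('filename', 'exp', 'sub', 'eye', 'blink', optionally roi_x/roi_y/roi_w).
    dataset_df = load_datasets(['datasets/recording01'])  # hypothetical path
    validate_data_files(dataset_df)

    train_ds, train_categories = get_loader(
        dataset_df, batch_size=8, shuffle=True,
        x_shape=(128, 128, 1), augment=True,
    )
    # Inspect one batch: images, segmentation masks, and [eye, blink] tags.
    for images, targets in train_ds.take(1):
        print(images.shape, targets['mask'].shape, targets['tags'].shape)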