# dataloader.py
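"""Data-loading utilities for pupil-segmentation training.

Builds tf.data pipelines from per-dataset annotation CSVs. The DataFrame
columns assumed here are inferred from their usage below: 'filename'
(input image path), 'target' (pupil-mask PNG path), 'eye', 'blink',
'exp', 'sub', and optionally 'roi_x', 'roi_y', 'roi_w'.
"""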
import os
from functools import partial

import pandas as pd
import tensorflow as tf


def validate_data_files(data_df):
"""Validate all files exist before starting training"""
missing_files = []
for _, row in data_df.iterrows():
if not os.path.exists(row['filename']):
missing_files.append(row['filename'])
if not os.path.exists(row['target']):
missing_files.append(row['target'])
if missing_files:
print("Missing files:")
for f in missing_files[:10]:
print(f" {f}")
if len(missing_files) > 10:
print(f" ... and {len(missing_files)-10} more")
raise FileNotFoundError(f"Found {len(missing_files)} missing files")


def _get_pupil_position(pmap, datum, x_shape):
    """Return the (y, x) pupil position as a float32 tensor.

    Uses the centroid of the pupil map when it has mass, otherwise the
    ROI centre (if available), otherwise the image centre.
    """
    total_mass = tf.reduce_sum(pmap)
if total_mass > 0:
shape = tf.shape(pmap)
h, w = shape[0], shape[1]
ii, jj = tf.meshgrid(tf.range(h), tf.range(w), indexing='ij')
y = tf.reduce_sum(tf.cast(ii, 'float32') * pmap) / total_mass
x = tf.reduce_sum(tf.cast(jj, 'float32') * pmap) / total_mass
return tf.stack((y, x))
    # Fall back to the ROI centre (assumes a square ROI of side roi_w)
    if 'roi_x' in datum and 'roi_y' in datum and 'roi_w' in datum:
roi_x = tf.cast(datum['roi_x'], 'float32')
roi_y = tf.cast(datum['roi_y'], 'float32')
half = tf.cast(datum['roi_w'] / 2, 'float32')
result = tf.stack((roi_y + half, roi_x + half))
else: # fallback to center of the image
result = tf.cast(tf.stack((x_shape[0] / 2, x_shape[1] / 2)), dtype='float32')
return result


def additional_augmentations(image, mask, p=0.3):
    """Randomly add noise, Gaussian blur, or sharpening to the image.

    The mask is returned unchanged; with probability 1 - p the image is
    left untouched as well.
    """
    # Keep the original image with probability 1 - p
if tf.random.uniform([]) > p:
return image, mask
# Random noise augmentation
if tf.random.uniform([]) < 0.3:
noise = tf.random.normal(tf.shape(image), mean=0, stddev=0.05)
image = image + noise
image = tf.clip_by_value(image, 0, 1)
    # Random Gaussian blur (sigma is bounded away from 0 to avoid a NaN kernel)
    if tf.random.uniform([]) < 0.3:
        kernel_size = 3
        sigma = tf.random.uniform([], 0.1, 1.0)
x = tf.range(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=tf.float32)
gaussian = tf.exp(-(x ** 2) / (2 * sigma ** 2))
gaussian = gaussian / tf.reduce_sum(gaussian)
gaussian = tf.reshape(gaussian, [kernel_size, 1])
gaussian_kernel = gaussian @ tf.transpose(gaussian)
gaussian_kernel = tf.reshape(gaussian_kernel, [kernel_size, kernel_size, 1, 1])
image = tf.nn.conv2d(tf.expand_dims(image, 0), gaussian_kernel,
strides=[1,1,1,1], padding='SAME')[0]
# Random sharpening
if tf.random.uniform([]) < 0.3:
kernel = tf.constant([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]], dtype=tf.float32)
kernel = tf.reshape(kernel, [3, 3, 1, 1])
image = tf.nn.conv2d(tf.expand_dims(image, 0), kernel,
strides=[1,1,1,1], padding='SAME')[0]
image = tf.clip_by_value(image, 0, 1)
return image, mask


@tf.function
def load_datum(datum, x_shape=(128, 128, 1), augment=False):
    """Load one (image, mask) pair and derive the (eye, blink) tag target."""
    try:
x = tf.io.read_file(datum['filename'])
y = tf.io.read_file(datum['target'])
# HWC [0,1] float32
x = tf.io.decode_image(x, channels=1, dtype='float32', expand_animations=False)
y = tf.io.decode_image(y, dtype='float32', expand_animations=False)
# Get image dimensions
h = tf.shape(x)[0]
w = tf.shape(x)[1]
        # Extract pupil information from the first target channel
        pupil_map = y[:, :, 0]
        pupil_area = tf.reduce_sum(pupil_map)
        pupil_pos_yx = _get_pupil_position(pupil_map, datum, x_shape)  # currently unused
# Target size we want to achieve
target_size = tf.minimum(tf.minimum(h, w), x_shape[0])
if not augment:
# Calculate center crop
h_start = (h - target_size) // 2
w_start = (w - target_size) // 2
        else:
            # Random crop within valid bounds (the pupil position is not enforced)
h_start = tf.random.uniform([], 0, h - target_size + 1, dtype=tf.int32)
w_start = tf.random.uniform([], 0, w - target_size + 1, dtype=tf.int32)
# Perform crop
x = tf.image.crop_to_bounding_box(x, h_start, w_start, target_size, target_size)
y = tf.image.crop_to_bounding_box(y, h_start, w_start, target_size, target_size)
if augment:
            # Random rotation by a multiple of 90 degrees
            k = tf.random.uniform([], 0, 4, dtype=tf.int32)
x = tf.image.rot90(x, k=k)
y = tf.image.rot90(y, k=k)
            # Random horizontal / vertical flips
if tf.random.uniform([]) < 0.5:
x = tf.image.flip_left_right(x)
y = tf.image.flip_left_right(y)
if tf.random.uniform([]) < 0.5:
x = tf.image.flip_up_down(x)
y = tf.image.flip_up_down(y)
# Apply additional augmentations
x, y = additional_augmentations(x, y)
        # Fraction of the pupil still visible after cropping/augmentation;
        # divide_no_nan returns 0 when pupil_area == 0 and avoids a Python
        # ternary on a tensor (AutoGraph does not convert conditional expressions)
        new_pupil_map = y[:, :, 0]
        new_pupil_area = tf.reduce_sum(new_pupil_map)
        eye = tf.math.divide_no_nan(new_pupil_area, pupil_area)
        # Process eye and blink labels
        datum_eye = tf.cast(datum['eye'], 'float32')
        datum_blink = tf.cast(datum['blink'], 'float32')
        # A closed eye cannot blink; an open, non-blinking eye gets the
        # visible-pupil fraction as a soft label
        datum_blink = tf.where(datum_eye == 0., 0., datum_blink)
        datum_eye = tf.where((datum_eye == 1.) & (datum_blink == 0.), eye, datum_eye)
        # Resize if needed, then keep only the pupil channel of the target
        # (slicing must happen on both branches so batch shapes stay consistent)
        if target_size != x_shape[0]:
            x = tf.image.resize(x, [x_shape[0], x_shape[0]])
            y = tf.image.resize(y, [x_shape[0], x_shape[0]])
        y = y[:, :, :1]
y2 = tf.stack((datum_eye, datum_blink))
return x, y, y2
    except Exception as e:
        # Note: under @tf.function this only catches errors raised while
        # tracing; runtime errors surface from the tf.data pipeline instead
        print(f"Error processing datum: {e}")
        raise


def get_loader(dataframe, batch_size=8, shuffle=False, **kwargs):
    """Build a batched, prefetching tf.data pipeline from an annotation DataFrame.

    Extra keyword arguments are forwarded to load_datum; the per-row 'exp'
    (experiment) labels are returned alongside the dataset.
    """
    categories = dataframe.exp.values
    dataset = tf.data.Dataset.from_tensor_slices(dict(dataframe))
if shuffle:
dataset = dataset.shuffle(1000)
dataset = dataset.map(
partial(load_datum, **kwargs),
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=not shuffle
)
dataset = dataset.batch(batch_size)
    def _pack_targets(*ins):
        inputs = ins[0]
        targets = {'mask': ins[1], 'tags': ins[2]}
        return inputs, targets
dataset = dataset.map(
_pack_targets,
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=not shuffle
)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset, categories


def load_datasets(dataset_dirs):
    """Load and merge annotation tables from one or more dataset directories."""
    def _load_and_prepare_annotations(dataset_dir):
# Normalize path
dataset_dir = os.path.normpath(dataset_dir)
data_path = os.path.join(dataset_dir, 'annotation', 'annotations.csv')
if not os.path.exists(data_path):
raise FileNotFoundError(f"Annotations file not found: {data_path}")
data = pd.read_csv(data_path)
# Create directories if they don't exist
png_dir = os.path.join(dataset_dir, 'annotation', 'png')
full_frames_dir = os.path.join(dataset_dir, 'fullFrames')
os.makedirs(png_dir, exist_ok=True)
os.makedirs(full_frames_dir, exist_ok=True)
        # Keep only rows whose source image exists on disk; .copy() avoids a
        # pandas SettingWithCopyWarning on the column assignments below
        valid_data = data[data.apply(
            lambda r: os.path.exists(
                os.path.join(full_frames_dir, os.path.basename(r['filename']))),
            axis=1,
        )].copy()
if len(valid_data) == 0:
raise ValueError(f"No valid image files found in {full_frames_dir}")
# Create target paths
valid_data['target'] = valid_data['filename'].apply(
lambda x: os.path.join(png_dir, os.path.splitext(os.path.basename(x))[0] + '.png')
)
valid_data['filename'] = valid_data['filename'].apply(
lambda x: os.path.join(full_frames_dir, os.path.basename(x))
)
return valid_data
all_data = []
for d in dataset_dirs:
try:
dataset = _load_and_prepare_annotations(d)
all_data.append(dataset)
except Exception as e:
print(f"Error loading dataset from {d}: {str(e)}")
continue
if not all_data:
raise ValueError("No valid datasets found in any of the provided directories")
    dataset = pd.concat(all_data, ignore_index=True)
    dataset['sub'] = dataset['sub'].astype(str)  # subject IDs as strings
print(f"Found {len(dataset)} valid image pairs")
return dataset
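

if __name__ == "__main__":
    # Minimal usage sketch. The default directory below is a placeholder for
    # illustration, not a path shipped with this module; the batch size and
    # augmentation settings are likewise only example values.
    import sys

    dirs = sys.argv[1:] or ['datasets/example']  # hypothetical default path
    data_df = load_datasets(dirs)
    validate_data_files(data_df)
    loader, exp_labels = get_loader(data_df, batch_size=8, shuffle=True, augment=True)
    for inputs, targets in loader.take(1):
        print(inputs.shape, targets['mask'].shape, targets['tags'].shape)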