import math
import os
from functools import partial

import pandas as pd
import tensorflow as tf


def validate_data_files(data_df):
    """Validate that every image and target file exists before starting training."""
    missing_files = []

    for _, row in data_df.iterrows():
        if not os.path.exists(row['filename']):
            missing_files.append(row['filename'])
        if not os.path.exists(row['target']):
            missing_files.append(row['target'])

    if missing_files:
        print("Missing files:")
        for f in missing_files[:10]:
            print(f"  {f}")
        if len(missing_files) > 10:
            print(f"  ... and {len(missing_files) - 10} more")
        raise FileNotFoundError(f"Found {len(missing_files)} missing files")


def _get_pupil_position(pmap, datum, x_shape):
    """Return the (y, x) pupil position for a sample.

    Uses the centre of mass of the pupil map when it is non-empty, falls back
    to the centre of the annotated ROI if ROI fields are available, and
    otherwise to the image centre.
    """
    total_mass = tf.reduce_sum(pmap)
    if total_mass > 0:
        # Centre of mass of the pupil map.
        shape = tf.shape(pmap)
        h, w = shape[0], shape[1]
        ii, jj = tf.meshgrid(tf.range(h), tf.range(w), indexing='ij')
        y = tf.reduce_sum(tf.cast(ii, 'float32') * pmap) / total_mass
        x = tf.reduce_sum(tf.cast(jj, 'float32') * pmap) / total_mass
        return tf.stack((y, x))

    if 'roi_x' in datum and 'roi_y' in datum and 'roi_w' in datum:
        # Centre of the annotated ROI.
        roi_x = tf.cast(datum['roi_x'], 'float32')
        roi_y = tf.cast(datum['roi_y'], 'float32')
        half = tf.cast(datum['roi_w'] / 2, 'float32')
        result = tf.stack((roi_y + half, roi_x + half))
    else:
        # Last resort: the image centre.
        result = tf.cast(tf.stack((x_shape[0] / 2, x_shape[1] / 2)), dtype='float32')

    return result


def additional_augmentations(image, mask, p=0.3):
    """Apply optional photometric augmentations (noise, blur, sharpen) to the image.

    The whole block is skipped with probability 1 - p; the mask is always
    returned unchanged.
    """
    if tf.random.uniform([]) > p:
        return image, mask

    # Additive Gaussian noise.
    if tf.random.uniform([]) < 0.3:
        noise = tf.random.normal(tf.shape(image), mean=0, stddev=0.05)
        image = tf.clip_by_value(image + noise, 0, 1)

    # Gaussian blur with a random sigma; the lower bound is kept positive so the
    # kernel never degenerates into a division by zero.
    if tf.random.uniform([]) < 0.3:
        kernel_size = 3
        sigma = tf.random.uniform([], 0.1, 1.0)
        x = tf.range(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=tf.float32)
        gaussian = tf.exp(-(x ** 2) / (2 * sigma ** 2))
        gaussian = gaussian / tf.reduce_sum(gaussian)
        gaussian = tf.reshape(gaussian, [kernel_size, 1])
        gaussian_kernel = gaussian @ tf.transpose(gaussian)
        gaussian_kernel = tf.reshape(gaussian_kernel, [kernel_size, kernel_size, 1, 1])
        image = tf.nn.conv2d(tf.expand_dims(image, 0), gaussian_kernel,
                             strides=[1, 1, 1, 1], padding='SAME')[0]

    # Sharpening with a fixed 3x3 kernel.
    if tf.random.uniform([]) < 0.3:
        kernel = tf.constant([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], dtype=tf.float32)
        kernel = tf.reshape(kernel, [3, 3, 1, 1])
        image = tf.nn.conv2d(tf.expand_dims(image, 0), kernel,
                             strides=[1, 1, 1, 1], padding='SAME')[0]
        image = tf.clip_by_value(image, 0, 1)

    return image, mask


@tf.function
def load_datum(datum, x_shape=(128, 128, 1), augment=False):
    """Load one image/mask pair, crop (and optionally augment) it, and build the targets.

    Returns (x, y, y2) where x is the grayscale image, y the single-channel
    pupil mask, and y2 the stacked (eye, blink) tags.
    """
    # Note: inside a tf.function this only catches errors raised while tracing;
    # errors raised during graph execution propagate to the caller instead.
    try:
        x = tf.io.read_file(datum['filename'])
        y = tf.io.read_file(datum['target'])

        x = tf.io.decode_image(x, channels=1, dtype='float32', expand_animations=False)
        y = tf.io.decode_image(y, dtype='float32', expand_animations=False)

        h = tf.shape(x)[0]
        w = tf.shape(x)[1]

        # Pupil statistics before cropping.
        pupil_map = y[:, :, 0]
        pupil_area = tf.reduce_sum(pupil_map)
        pupil_pos_yx = _get_pupil_position(pupil_map, datum, x_shape)

        # Square crop no larger than the requested output size.
        target_size = tf.minimum(tf.minimum(h, w), x_shape[0])

        if not augment:
            # Deterministic centre crop for validation/inference.
            h_start = (h - target_size) // 2
            w_start = (w - target_size) // 2
        else:
            # Random crop for training.
            h_start = tf.random.uniform([], 0, h - target_size + 1, dtype=tf.int32)
            w_start = tf.random.uniform([], 0, w - target_size + 1, dtype=tf.int32)

        x = tf.image.crop_to_bounding_box(x, h_start, w_start, target_size, target_size)
        y = tf.image.crop_to_bounding_box(y, h_start, w_start, target_size, target_size)

        if augment:
            # Random 90-degree rotations and flips, applied identically to image and mask.
            k = tf.random.uniform([], 0, 4, dtype=tf.int32)
            x = tf.image.rot90(x, k=k)
            y = tf.image.rot90(y, k=k)

            if tf.random.uniform([]) < 0.5:
                x = tf.image.flip_left_right(x)
                y = tf.image.flip_left_right(y)

            if tf.random.uniform([]) < 0.5:
                x = tf.image.flip_up_down(x)
                y = tf.image.flip_up_down(y)

            x, y = additional_augmentations(x, y)

        # Fraction of the pupil that survived the crop (0 when the original map is
        # empty); divide_no_nan keeps this graph-safe, since a Python conditional
        # expression on a tensor cannot be evaluated inside a tf.function.
        new_pupil_map = y[:, :, 0]
        new_pupil_area = tf.reduce_sum(new_pupil_map)
        eye = tf.math.divide_no_nan(new_pupil_area, pupil_area)

        datum_eye = tf.cast(datum['eye'], 'float32')
        datum_blink = tf.cast(datum['blink'], 'float32')

        # Force blink to 0 when the eye tag is 0; for eye == 1 with no blink, use
        # the visible-pupil fraction after cropping as the eye target.
        if datum_eye == 0:
            datum_blink = 0.
        if (datum_eye == 1) & (datum_blink == 0):
            datum_eye = eye

        # Resize to the model input size if the crop came out smaller.
        if target_size != x_shape[0]:
            x = tf.image.resize(x, [x_shape[0], x_shape[0]])
            y = tf.image.resize(y, [x_shape[0], x_shape[0]])

        y = y[:, :, :1]
        y2 = tf.stack((datum_eye, datum_blink))

        return x, y, y2

    except Exception as e:
        print(f"Error processing datum: {str(e)}")
        raise


def get_loader(dataframe, batch_size=8, shuffle=False, **kwargs):
    """Build a tf.data pipeline from an annotation dataframe.

    Extra keyword arguments are forwarded to load_datum (e.g. augment, x_shape).
    Returns the batched dataset together with the per-row experiment categories.
    """
    categories = dataframe.exp.values
    dataset = tf.data.Dataset.from_tensor_slices(dict(dataframe))

    if shuffle:
        dataset = dataset.shuffle(1000)

    dataset = dataset.map(
        partial(load_datum, **kwargs),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=not shuffle
    )
    dataset = dataset.batch(batch_size)

    def _pack_targets(*ins):
        # Pack the two targets into the named outputs expected by the model and
        # return an (inputs, targets) tuple, which is what Keras expects.
        inputs = ins[0]
        targets = {'mask': ins[1], 'tags': ins[2]}
        return inputs, targets

    dataset = dataset.map(
        _pack_targets,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=not shuffle
    )
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset, categories


def load_datasets(dataset_dirs):
    """Load and merge annotation tables from a list of dataset directories."""

    def _load_and_prepare_annotations(dataset_dir):
        dataset_dir = os.path.normpath(dataset_dir)

        data_path = os.path.join(dataset_dir, 'annotation', 'annotations.csv')
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Annotations file not found: {data_path}")

        data = pd.read_csv(data_path)

        png_dir = os.path.join(dataset_dir, 'annotation', 'png')
        full_frames_dir = os.path.join(dataset_dir, 'fullFrames')
        os.makedirs(png_dir, exist_ok=True)
        os.makedirs(full_frames_dir, exist_ok=True)

        # Keep only rows whose source frame actually exists on disk; copy so that
        # adding columns below does not trigger pandas SettingWithCopy warnings.
        valid_data = data[data.apply(
            lambda x: os.path.exists(os.path.join(full_frames_dir, os.path.basename(x['filename']))),
            axis=1
        )].copy()

        if len(valid_data) == 0:
            raise ValueError(f"No valid image files found in {full_frames_dir}")

        # Resolve the target mask path and the absolute frame path for each row.
        valid_data['target'] = valid_data['filename'].apply(
            lambda x: os.path.join(png_dir, os.path.splitext(os.path.basename(x))[0] + '.png')
        )
        valid_data['filename'] = valid_data['filename'].apply(
            lambda x: os.path.join(full_frames_dir, os.path.basename(x))
        )

        return valid_data

    all_data = []
    for d in dataset_dirs:
        try:
            dataset = _load_and_prepare_annotations(d)
            all_data.append(dataset)
        except Exception as e:
            print(f"Error loading dataset from {d}: {str(e)}")
            continue

    if not all_data:
        raise ValueError("No valid datasets found in any of the provided directories")

    dataset = pd.concat(all_data, ignore_index=True)
    dataset['sub'] = dataset['sub'].astype(str)

    print(f"Found {len(dataset)} valid image pairs")
    return dataset
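

# Example usage (a minimal sketch: the directory paths below are placeholders,
# and the batch size / augmentation settings are illustrative assumptions, not
# values prescribed by this module):
#
#     data_df = load_datasets(['datasets/session01', 'datasets/session02'])
#     validate_data_files(data_df)
#     train_ds, categories = get_loader(
#         data_df, batch_size=8, shuffle=True,
#         augment=True, x_shape=(128, 128, 1),
#     )
#     # Each batch yields (images, {'mask': masks, 'tags': tags}), ready to be
#     # passed to model.fit for a model with outputs named 'mask' and 'tags'.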