pupil_repo / losses.py
g30rv17ys's picture
Add files using upload-large-folder tool
fd4b932 verified
raw
history blame
673 Bytes
import tensorflow as tf
from tensorflow.keras import backend as K
def boundary_loss(y_true, y_pred, weight=0.5):
    """Additional loss focusing on boundaries.

    Compares the vertical and horizontal image gradients of the target and
    the prediction and returns the mean absolute difference, scaled by
    ``weight``.

    Args:
        y_true: Ground-truth tensor. It is cast to the dtype of ``y_pred``
            so that integer or float64 masks can be used directly.
        y_pred: Predicted tensor with the same shape as ``y_true``.
        weight: Scalar multiplier applied to the mean gradient difference.
            Defaults to 0.5.

    Returns:
        A scalar tensor: ``weight * mean(|dy_pred - dy_true| +
        |dx_pred - dx_true|)``.

    Note:
        Inputs are passed to ``tf.image.image_gradients``, which is
        documented as taking 4-D ``[batch, height, width, channels]``
        tensors -- TODO confirm callers always supply that layout.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # The subtraction below needs matching dtypes. Keras' own
    # binary_crossentropy casts y_true the same way, so this keeps the two
    # loss terms usable with the same label tensors.
    y_true = tf.cast(y_true, y_pred.dtype)

    # Compute gradients
    dy_true, dx_true = tf.image.image_gradients(y_true)
    dy_pred, dx_pred = tf.image.image_gradients(y_pred)

    # Compute boundary loss
    gradient_error = tf.abs(dy_pred - dy_true) + tf.abs(dx_pred - dx_true)
    return tf.reduce_mean(gradient_error) * weight
def enhanced_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy augmented with the boundary-gradient penalty.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.

    Returns:
        The output of ``tf.keras.losses.binary_crossentropy(y_true, y_pred)``
        with the value of ``boundary_loss(y_true, y_pred)`` added to it.
    """
    edge_term = boundary_loss(y_true, y_pred)
    pixel_term = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return pixel_term + edge_term