# pupil_repo / convert_model.py
# Added by g30rv17ys with the upload-large-folder tool (commit fd4b932, 2.16 kB).
import os

import tensorflow as tf
from tensorflow.keras import backend as K
from adabelief_tf import AdaBeliefOptimizer
def iou_coef(y_true, y_pred):
    """Soft intersection-over-union, computed per sample and averaged over the batch.

    Both inputs are cast to float32. For each sample the intersection is
    sum(|y_true * y_pred|) and the union is sum(y_true) + sum(y_pred) minus
    that intersection, all reduced over axes 1, 2 and 3. A 1e-6 term is added
    to numerator and denominator so an empty prediction on an empty mask
    scores 1 instead of producing 0/0.

    Args:
        y_true: Ground-truth mask tensor. The axis=[1, 2, 3] reduction needs
            rank >= 4 (presumably (batch, height, width, channels) -- confirm
            against the training code).
        y_pred: Predicted mask / probabilities with the same shape as y_true.

    Returns:
        Scalar tensor: mean per-sample IoU.
    """
    smooth = 1e-6
    reduce_axes = [1, 2, 3]
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)

    overlap = K.sum(K.abs(truth * pred), axis=reduce_axes)
    combined = K.sum(truth, axis=reduce_axes) + K.sum(pred, axis=reduce_axes)
    union = combined - overlap

    return K.mean((overlap + smooth) / (union + smooth))
def dice_coef(y_true, y_pred):
    """Soft Dice coefficient, computed per sample and averaged over the batch.

    Both inputs are cast to float32. Per sample:
    (2 * sum(|y_true * y_pred|) + 1e-6) / (sum(y_true) + sum(y_pred) + 1e-6),
    with sums taken over axes 1, 2 and 3. The 1e-6 smoothing keeps the ratio
    defined (equal to 1) when both masks are empty.

    Args:
        y_true: Ground-truth mask tensor; rank >= 4 is required by the
            axis=[1, 2, 3] reduction (presumably (batch, H, W, C) -- confirm).
        y_pred: Predicted mask / probabilities with the same shape as y_true.

    Returns:
        Scalar tensor: mean per-sample Dice score.
    """
    smooth = 1e-6
    axes = [1, 2, 3]
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)

    overlap = K.sum(K.abs(truth * pred), axis=axes)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth, axis=axes) + K.sum(pred, axis=axes) + smooth

    return K.mean(numerator / denominator)
def boundary_loss(y_true, y_pred):
    """Penalize mismatched spatial gradients between prediction and target.

    Both inputs are cast to float32. tf.image.image_gradients yields the
    vertical (dy) and horizontal (dx) finite differences of each tensor; the
    loss is the mean over all elements of |dy_pred - dy_true| +
    |dx_pred - dx_true|, scaled by 0.5.

    Args:
        y_true: Ground-truth tensor; image_gradients expects a 4-D
            (batch, height, width, channels) input.
        y_pred: Prediction tensor with the same shape as y_true.

    Returns:
        Scalar tensor.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    grad_y_true, grad_x_true = tf.image.image_gradients(y_true)
    grad_y_pred, grad_x_pred = tf.image.image_gradients(y_pred)

    vertical_err = tf.abs(grad_y_pred - grad_y_true)
    horizontal_err = tf.abs(grad_x_pred - grad_x_true)

    return 0.5 * tf.reduce_mean(vertical_err + horizontal_err)
def enhanced_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy augmented with the gradient-matching boundary term.

    Both inputs are cast to float32. The result is the Keras binary
    cross-entropy (which Keras reduces over the last axis) plus the scalar
    returned by boundary_loss, which broadcasts over the cross-entropy
    tensor.

    Args:
        y_true: Ground-truth tensor (4-D, as required by boundary_loss).
        y_pred: Predicted probabilities with the same shape as y_true.

    Returns:
        Loss tensor: cross-entropy values with the boundary penalty added.
    """
    y_true, y_pred = (tf.cast(t, tf.float32) for t in (y_true, y_pred))
    return (
        tf.keras.losses.binary_crossentropy(y_true, y_pred)
        + boundary_loss(y_true, y_pred)
    )
def hard_swish(x):
    """Hard-swish activation: x * relu6(x + 3) / 6.

    A piecewise-linear stand-in for x * sigmoid(x). The 1/6 scaling is
    applied as a multiplication by (1. / 6.); dividing by 6 instead is not
    always bit-identical in floating point.
    """
    gate = tf.nn.relu6(x + 3)
    return x * gate * (1. / 6.)
# Path to the trained model saved in the native Keras (.keras) format.
keras_path = 'runs/b32_c-conv_d-|root|meye|data|NN_human_mouse_eyes|_g1.5_l0.001_num_c1_num_f16_num_s5_r128_se23_sp-random_up-relu_us0/best_model.keras'

# Name -> object mapping handed to load_model so the custom optimizer,
# activation, losses and metrics referenced by name in the saved file can be
# resolved (presumably the same definitions used at training time).
custom_objects = {
    'AdaBeliefOptimizer': AdaBeliefOptimizer,
    'iou_coef': iou_coef,
    'dice_coef': dice_coef,
    'hard_swish': hard_swish,
    'enhanced_binary_crossentropy': enhanced_binary_crossentropy,
    'boundary_loss': boundary_loss,
}

print("Loading model from:", keras_path)
model = tf.keras.models.load_model(keras_path, custom_objects=custom_objects)

# Replace only the file extension. str.replace('.keras', '.h5') would rewrite
# every '.keras' occurrence anywhere in the path (e.g. a '~/.keras/...'
# directory), and would return the path unchanged -- so the save below would
# overwrite the source model -- if the name did not contain '.keras' at all.
h5_path = os.path.splitext(keras_path)[0] + '.h5'
print("Saving model to:", h5_path)
# save_format='h5' selects the legacy HDF5 format on tf.keras 2.x; Keras 3
# infers the format from the '.h5' suffix and only logs a deprecation warning
# for this argument.
model.save(h5_path, save_format='h5')
print("Conversion complete!")