{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "*#Image to Image Translation#*" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tensorflow==2.15.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (2.15.0)\n", "Requirement already satisfied: tensorflow-intel==2.15.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow==2.15.0) (2.15.0)\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (1.70.0)\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (0.31.0)\n", "Requirement already satisfied: numpy<2.0.0,>=1.23.5 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (1.26.4)\n", "Requirement already satisfied: packaging in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (24.2)\n", "Requirement already satisfied: tensorflow-estimator<2.16,>=2.15.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (2.15.0)\n", "Requirement already satisfied: libclang>=13.0.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (18.1.1)\n", "Requirement already satisfied: six>=1.12.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (1.17.0)\n", "Requirement already satisfied: absl-py>=1.0.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (2.1.0)\n", "Requirement already satisfied: setuptools in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (65.5.0)\n", "Requirement already satisfied: h5py>=2.9.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (3.12.1)\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (0.6.0)\n", "Requirement already satisfied: flatbuffers>=23.5.26 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (25.1.24)\n", "Requirement already satisfied: google-pasta>=0.1.1 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (0.2.0)\n", "Requirement already satisfied: termcolor>=1.1.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (2.5.0)\n", "Requirement already satisfied: astunparse>=1.6.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (1.6.3)\n", "Requirement already satisfied: typing-extensions>=3.6.6 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from 
tensorflow-intel==2.15.0->tensorflow==2.15.0) (4.12.2)\n", "Requirement already satisfied: tensorboard<2.16,>=2.15 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (2.15.2)\n", "Requirement already satisfied: keras<2.16,>=2.15.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (2.15.0)\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (3.20.3)\n", "Requirement already satisfied: ml-dtypes~=0.2.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (0.2.0)\n", "Requirement already satisfied: wrapt<1.15,>=1.11.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (1.14.1)\n", "Requirement already satisfied: opt-einsum>=2.3.2 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-intel==2.15.0->tensorflow==2.15.0) (3.4.0)\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from astunparse>=1.6.0->tensorflow-intel==2.15.0->tensorflow==2.15.0) (0.45.1)\n", "Requirement already satisfied: requests<3,>=2.21.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (2.32.3)\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (0.7.2)\n", "Requirement already satisfied: werkzeug>=1.0.1 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (3.1.3)\n", "Requirement already satisfied: markdown>=2.6.8 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (3.7)\n", "Requirement already satisfied: google-auth<3,>=1.6.3 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (2.38.0)\n", "Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (1.2.1)\n", "Requirement already satisfied: cachetools<6.0,>=2.0.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (5.5.1)\n", "Requirement already satisfied: pyasn1-modules>=0.2.1 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (0.4.1)\n", "Requirement already satisfied: rsa<5,>=3.1.4 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (4.9)\n", "Requirement already satisfied: requests-oauthlib>=0.7.0 in 
d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (2.0.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (3.4.1)\n", "Requirement already satisfied: idna<4,>=2.5 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (3.10)\n", "Requirement already satisfied: certifi>=2017.4.17 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (2025.1.31)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (2.3.0)\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from werkzeug>=1.0.1->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (3.0.2)\n", "Requirement already satisfied: pyasn1<0.7.0,>=0.4.6 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (0.6.1)\n", "Requirement already satisfied: oauthlib>=3.0.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow-intel==2.15.0->tensorflow==2.15.0) (3.2.2)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "[notice] A new release of pip is available: 23.0.1 -> 25.0.1\n", "[notice] To update, run: python.exe -m pip install --upgrade pip\n" ] } ], "source": [ "!pip install tensorflow==2.15.0\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tensorflow-probability==0.23.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (0.23.0)\n", "Requirement already satisfied: six>=1.10.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-probability==0.23.0) (1.17.0)\n", "Requirement already satisfied: decorator in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-probability==0.23.0) (5.1.1)\n", "Requirement already satisfied: cloudpickle>=1.3 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-probability==0.23.0) (3.1.1)\n", "Requirement already satisfied: gast>=0.3.2 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-probability==0.23.0) (0.6.0)\n", "Requirement already satisfied: numpy>=1.13.3 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-probability==0.23.0) (1.26.4)\n", "Requirement already satisfied: absl-py in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-probability==0.23.0) (2.1.0)\n", "Requirement already satisfied: dm-tree in d:\\vs code\\web 
dev\\projects\\image2image\\image\\lib\\site-packages (from tensorflow-probability==0.23.0) (0.1.9)\n", "Requirement already satisfied: attrs>=18.2.0 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from dm-tree->tensorflow-probability==0.23.0) (25.1.0)\n", "Requirement already satisfied: wrapt>=1.11.2 in d:\\vs code\\web dev\\projects\\image2image\\image\\lib\\site-packages (from dm-tree->tensorflow-probability==0.23.0) (1.14.1)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "[notice] A new release of pip is available: 23.0.1 -> 25.0.1\n", "[notice] To update, run: python.exe -m pip install --upgrade pip\n" ] } ], "source": [ "!pip install tensorflow-probability==0.23.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*1️⃣ Import Necessary Libraries*\n", "\n", "1. *TensorFlow/Keras* for building and training deep learning models.\n", "\n", "2. *NumPy* for numerical operations.\n", "\n", "3. *Matplotlib* for visualizing the results.\n", "\n", "4. *OpenCV* for loading and preprocessing the images.\n", "\n", "5. *TensorFlow Probability* for the probabilistic parts of the VAE (latent distributions).\n", "\n", "6. *tf.data* (with the custom OpenCV loaders defined below) to build the CT & MRI training dataset." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From d:\\VS CODE\\Web Dev\\Projects\\Image2Image\\image\\lib\\site-packages\\keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n", "\n", "WARNING:tensorflow:From d:\\VS CODE\\Web Dev\\Projects\\Image2Image\\image\\lib\\site-packages\\tensorflow_probability\\python\\internal\\backend\\numpy\\_utils.py:48: The name tf.logging.TaskLevelStatusMessage is deprecated. Please use tf.compat.v1.logging.TaskLevelStatusMessage instead.\n", "\n", "WARNING:tensorflow:From d:\\VS CODE\\Web Dev\\Projects\\Image2Image\\image\\lib\\site-packages\\tensorflow_probability\\python\\internal\\backend\\numpy\\_utils.py:48: The name tf.control_flow_v2_enabled is deprecated. Please use tf.compat.v1.control_flow_v2_enabled instead.\n", "\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow.keras import layers, Model\n", "import numpy as np\n", "import cv2\n", "import pathlib\n", "import matplotlib.pyplot as plt\n", "import tensorflow_probability as tfp\n", "\n", "tfd = tfp.distributions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2️⃣ *Configuration (Hyperparameters)*\n", "\n", "This step defines the key settings for training.\n", "\n", "Image size: The input image dimensions.\n", "\n", "Latent dimension: The size of the encoded representation in the VAE.\n", "\n", "Learning rate: Defines how fast the model updates weights.\n", "\n", "Batch size & epochs: How many images are processed per update step, and how many passes are made over the dataset."
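, "\n", "For example, with IMAGE_SHAPE = (256, 256, 3) the generator's seven encoder blocks (each ending in max pooling) halve the spatial size seven times, so the bottleneck feature map is 2x2, which is why the generator later reshapes the latent vector to (2, 2, FILTERS*64). A quick, purely illustrative sanity check of that arithmetic:\n", "\n", "```python\n", "size = 256            # IMAGE_SHAPE[0]\n", "for _ in range(7):    # seven encoder blocks, each followed by MaxPooling2D\n", "    size //= 2        # 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2\n", "print(size)           # 2, matching Reshape((2, 2, FILTERS*64)) in the generator\n", "```"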
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "IMAGE_SHAPE = (256, 256, 3)\n", "LATENT_DIM = 256\n", "FILTERS = 16\n", "KERNEL = 3\n", "LEARNING_RATE = 0.0001\n", "WEIGHT_DECAY = 6e-8\n", "BATCH_SIZE = 1\n", "EPOCHS = 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* ===================== Architecture Components =====================*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*3️⃣ Sampling Layer for Variational Autoencoder (VAE)*\n", "\n", "The sampling layer is a crucial part of the VAE, where we sample from the latent space.\n", "\n", "🔹 What We Need \n", "\n", "The encoder outputs μ (mean) and σ (log variance).\n", "\n", "This layer samples from a normal distribution using the reparameterization trick." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Sampling(layers.Layer):\n", " def call(self, inputs):\n", " z_mean, z_log_var = inputs\n", " batch = tf.shape(z_mean)[0]\n", " dim = tf.shape(z_mean)[1]\n", " epsilon = tf.random.normal(shape=(batch, dim))\n", " return z_mean + tf.exp(0.5 * z_log_var) * epsilon" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It inherits from layers.Layer, meaning it's a custom layer that can be used like any other Keras layer.\n", "\n", "📍inputs is a tuple containing:\n", "\n", " z_mean: The mean vector output from the encoder.\n", "\n", " z_log_var: The log variance vector output from the encoder.\n", "\n", "📍This unpacks the inputs into two separate variables:\n", "\n", " z_mean: Represents the mean of the latent distribution.\n", "\n", " z_log_var: Represents the log variance of the latent distribution.\n", "\n", "📍Why log variance?\n", "\n", "Instead of using variance (σ²), we use log(σ²) because:\n", "\n", "Numerical Stability: Log prevents exploding/vanishing gradients.\n", "\n", "Easier Optimization: exp(log(σ²) / 2) makes variance always positive.\n", "\n", "\n", "📍This determines:\n", " batch: The number of samples in the batch.\n", " dim: The size of the latent space (e.g., 128 if LATENT_DIM = 128).\n", "\n", "\n", "📍Generates random values from a standard normal distribution (𝒩(0,1)).\n", "\n", "epsilon.shape = (batch, dim), meaning every sample gets a unique noise vector.\n", "\n", "Why do we need epsilon?\n", "\n", "Instead of directly using z_mean, we add controlled randomness to ensure the VAE learns a smooth latent space.\n", "\n", "* Reparameterization Trick*\n", "\n", " The latent space follows a normal distribution:\n", "\n", " 𝑧 ∼ 𝒩(μ, σ²)\n", "\n", " A sample is drawn from this distribution:\n", "\n", " 𝑧 = μ + σ * ε, where ε ∼ 𝒩(0,1).\n", "\n", " Since z_log_var = log(σ²), we compute:\n", "\n", " σ = exp(0.5 * log(σ²)) = exp(0.5 * z_log_var).\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*🔹 Residual Block in Detail*\n", "\n", "This function defines a residual block, a key building block inspired by ResNet (Residual Networks). 
Residual blocks help in training deep neural networks efficiently by allowing gradients to flow through skip connections.\n", "\n", "inputs: The input tensor (features from the previous layer).\n", "\n", "filters: The number of filters (channels) in the convolution layers.\n", "\n", "use_norm: Whether to apply Group Normalization (helps stabilize training).\n", "\n", "Step 1️⃣: First Convolution + Activation\n", "\n", " Applies a 2D convolution (Conv2D) with the given number of filters.\n", "\n", " KERNEL is the global kernel-size constant defined in the configuration (3, i.e., a 3x3 kernel).\n", "\n", " padding='same': Ensures the output size is the same as the input.\n", "\n", " Leaky ReLU activation (alpha=0.2):\n", "\n", " Helps avoid dead neurons (better than regular ReLU).\n", " \n", " Allows a small gradient flow for negative values.\n", "\n", "Step 2️⃣: Group Normalization (Optional)\n", "\n", "Step 3️⃣: Second Convolution + Activation\n", "\n", " Applies another Conv2D layer with the same number of filters.\n", "\n", " Uses LeakyReLU again for better gradient flow.\n", "\n", " Why two convolutions?\n", "\n", " The first convolution learns low-level features.\n", " \n", " The second convolution refines the learned features.\n", "\n", "Step 4️⃣: Shortcut Connection (Skip Connection)\n", "\n", " The original input is passed through a 1x1 convolution.\n", "\n", " This matches the number of filters with the residual output.\n", "\n", " Why 1x1 convolution?\n", "\n", " Ensures the shortcut has the same number of filters as x.\n", "\n", " Helps in adjusting dimensions when the number of channels changes.\n", "\n", "Step 5️⃣: Merge Shortcut & Residual Path\n", "\n", " Merges the shortcut and residual path using an element-wise maximum.\n", " \n", " Why maximum() instead of addition (+)?\n", "\n", " Acts like a max-out merge, which can help training stability.\n", " \n", " Focuses on the stronger features from either the residual or shortcut path.\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def residual_block(inputs, filters, use_norm=True):\n", "    x = layers.Conv2D(filters, KERNEL, padding='same')(inputs)\n", "    x = layers.LeakyReLU(alpha=0.2)(x)\n", "    if use_norm:\n", "        x = layers.GroupNormalization(groups=1)(x)\n", "    x = layers.Conv2D(filters, KERNEL, padding='same')(x)\n", "    x = layers.LeakyReLU(alpha=0.2)(x)\n", "    if use_norm:\n", "        x = layers.GroupNormalization(groups=1)(x)\n", "    shortcut = layers.Conv2D(filters, 1, padding='same')(inputs)\n", "    return layers.maximum([x, shortcut])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*1️⃣ Encoder and Decoder Block*\n", "\n", "*Encoder*\n", "\n", "1️⃣ Pass Input Through Residual Block\n", "\n", "Uses a residual block (previously defined).\n", "\n", "Extracts important features while keeping the original information.\n", "\n", "Helps prevent vanishing gradients and allows deep networks to train effectively.\n", "\n", "2️⃣ Store the Skip Connection\n", "\n", "The skip connection stores the output of the residual block.\n", "\n", "It will be used later in the decoder to restore lost details.\n", "\n", "3️⃣ Downsampling (Reduce Spatial Size)\n", "\n", "Applies Max Pooling to reduce the spatial size (height & width).\n", "\n", "Why?\n", "\n", "Reduces computation.\n", "\n", "Forces the model to learn high-level features instead of pixel details.\n", "\n", "4️⃣ Return Downsampled Output & Skip Connection\n", "\n", "Outputs:\n", "\n", "x: The downsampled 
feature map.\n", "\n", "skip: The saved feature map (used later in the decoder).\n", "\n", "\n", "🔥 2️⃣ Decoder Block\n", "\n", "1️⃣ Upsampling (Increase Spatial Size):\n", "\n", "Uses Conv2DTranspose (transposed convolution, aka deconvolution).\n", "\n", "Upsamples the input by a factor of 2 (increases spatial size).\n", "\n", "Why?\n", "\n", "Increases resolution to match the original input image.\n", "\n", "2️⃣ Merge Skip Connection\n", "\n", "Combines the upsampled output with the skip connection.\n", "\n", "Uses element-wise maximum instead of addition.\n", "\n", "Why?\n", "\n", "Ensures the model focuses on the most important features.\n", "\n", "Prevents loss of key information during encoding.\n", "\n", "3️⃣ Apply a Residual Block\n", "\n", "Uses a residual block to refine the upsampled output.\n", "\n", "Helps recover lost details and maintain stability.\n", "\n", "4️⃣ Return the Processed Output\n", "\n", "Returns the final feature map after upsampling and refinement.\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def encoder_block(inputs, filters, use_norm=True):\n", " x = residual_block(inputs, filters, use_norm)\n", " skip = x\n", " x = layers.MaxPooling2D()(x)\n", " return x, skip\n", "\n", "def decoder_block(inputs, skip, filters, use_norm=True):\n", " x = layers.Conv2DTranspose(filters, KERNEL, strides=2, padding='same')(inputs)\n", " x = layers.maximum([x, skip])\n", " x = residual_block(x, filters, use_norm)\n", " return x\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* ===================== Generator =====================*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function builds the generator model for a Variational Autoencoder (VAE) with a CycleGAN architecture. 
The generator is responsible for converting a CT scan into an MRI image (or vice versa) by learning to map the two domains.\n", "\n", ".\n", "\n", "🛠️ What This Function Does?\n", "\n", "It encodes an input image into a latent space.\n", "\n", "It applies variational sampling to introduce a probabilistic distribution.\n", "\n", "It decodes the latent representation back into an image.\n", "\n", "Uses skip connections to retain features across layers.\n", "\n", "1️⃣ Input Layer\n", "\n", "Defines the input tensor with the given IMAGE_SHAPE (e.g., (256, 256, 3), for RGB images).\n", "\n", "2️⃣ Encoder: Downsampling the Image\n", "\n", " Each encoder block halves the spatial resolution but doubles the filters.\n", "\n", " Stores skip connections (s1, s2, ..., s7) for later use in the decoder.\n", "\n", " After e7, the image is highly compressed into a feature map.\n", "\n", "3️⃣ Latent Space (Variational Sampling)\n", "\n", " Flattens the feature map into a 1D vector.\n", "\n", " Uses two dense layers to compute:\n", "\n", " z_mean → The mean of the latent distribution.\n", "\n", " z_log_var → The logarithm of the variance.\n", "\n", " Uses reparameterization trick (Sampling layer) to ensure backpropagation works in VAE.\n", "\n", "4️⃣ Reshape for Decoder\n", "\n", " Expands z into a 2x2 feature map to match e7 dimensions.\n", "\n", " Prepares the latent vector for decoding.\n", "\n", "5️⃣ Decoder: Upsampling the Image\n", "\n", " Each decoder block upsamples the feature map back to the original size.\n", "\n", " Uses skip connections (s1, s2, ..., s7) to restore spatial information.\n", "\n", " Mirrors the encoder process but in reverse.\n", "\n", "6️⃣ Final Output Layer\n", "\n", " Uses a Conv2D layer to produce the final RGB image.\n", " \n", " Applies sigmoid activation to ensure pixel values remain between [0,1].\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def build_generator(name):\n", " inputs = layers.Input(IMAGE_SHAPE)\n", " \n", " # Encoder\n", " e1, s1 = encoder_block(inputs, FILTERS)\n", " e2, s2 = encoder_block(e1, FILTERS*2)\n", " e3, s3 = encoder_block(e2, FILTERS*4)\n", " e4, s4 = encoder_block(e3, FILTERS*8)\n", " e5, s5 = encoder_block(e4, FILTERS*16)\n", " e6, s6 = encoder_block(e5, FILTERS*32)\n", " e7, s7 = encoder_block(e6, FILTERS*64)\n", " \n", " # Latent Space\n", " x = layers.Flatten()(e7)\n", " z_mean = layers.Dense(LATENT_DIM, name=f\"z_mean_{name.split('_')[-1]}\")(x)\n", " z_log_var = layers.Dense(LATENT_DIM, name=f\"z_log_var_{name.split('_')[-1]}\")(x)\n", " z = Sampling()([z_mean, z_log_var])\n", " \n", " # Reshape for decoder\n", " x = layers.Dense(2 * 2 * FILTERS*64)(z)\n", " x = layers.Reshape((2, 2, FILTERS*64))(x)\n", " \n", " # Decoder\n", " d0 = decoder_block(x, s7, FILTERS*64)\n", " d1 = decoder_block(d0, s6, FILTERS*32)\n", " d2 = decoder_block(d1, s5, FILTERS*16)\n", " d3 = decoder_block(d2, s4, FILTERS*8)\n", " d4 = decoder_block(d3, s3, FILTERS*4)\n", " d5 = decoder_block(d4, s2, FILTERS*2)\n", " d6 = decoder_block(d5, s1, FILTERS)\n", " \n", " outputs = layers.Conv2D(3, KERNEL, activation='sigmoid', padding='same')(d6)\n", " return Model(inputs, [outputs, z_mean, z_log_var], name=name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "*===================== Discriminator =====================*\n", "\n", "\n", "This function constructs the discriminator in a Generative Adversarial Network (GAN). 
The discriminator’s role is to classify an image as real or fake by extracting hierarchical features and making multi-scale predictions.\n", "\n", "What Does This Function Do?\n", "\n", " Extracts features from the input image using convolutional layers.\n", "\n", " Downsamples the image through multiple layers to capture both local and global features.\n", "\n", " Generates multiple outputs from different feature scales for better discrimination.\n", "\n", "1️⃣ Input Layer \n", "\n", " Defines the input tensor with a shape of IMAGE_SHAPE (e.g., (256, 256, 3) for RGB images).\n", "\n", " This means the discriminator takes an image as input.\n", "\n", "2️⃣ Feature Extraction\n", "\n", " x = inputs initializes x as the input image.\n", "\n", " features = [] creates a list to store intermediate feature map\n", "\n", "3️⃣ Initial Convolution\n", "\n", " Applies a convolutional layer (Conv2D) with FILTERS (e.g., 64 filters) to extract basic edges and textures.\n", "\n", " Uses LeakyReLU activation (alpha=0.2) instead of ReLU to allow small gradients for negative values.\n", "\n", " Stores the feature map in features.\n", "\n", "4️⃣ Downsampling Blocks (Feature Hierarchy)\n", "\n", " Defines filter_sizes, increasing filter count at each stage to learn complex features.\n", "\n", " Uses a loop to pass x through multiple encoder_block layers:\n", "\n", " Each encoder_block downsamples the feature map (reducing spatial size).\n", "\n", " Each block doubles the number of filters to capture more detailed features.\n", "\n", " Stores all extracted feature maps in features.\n", "\n", "5️⃣ Multi-Scale Outputs (Final Classification Layers)\n", "\n", " The discriminator does not produce a single output; it uses multiple feature scales.\n", "\n", " Extracts the last 4 feature maps (features[-4:]) to classify at different resolutions.\n", "\n", " Each feature map is passed through a final Conv2D layer with 1 filter to predict real vs fake scores.\n", "\n", " Stores the outputs in outputs.\n", "\n", "6️⃣ Return the Discriminator Model\n", "\n", " Creates a Keras Model that takes an image as input and outputs multiple classification scores.\n", "\n", " This helps in making fine-grained real/fake decisions.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def build_discriminator(name):\n", " inputs = layers.Input(IMAGE_SHAPE)\n", " \n", " # Feature extraction\n", " x = inputs\n", " features = []\n", " \n", " # Initial convolution\n", " x = layers.Conv2D(FILTERS, KERNEL, padding='same')(x)\n", " x = layers.LeakyReLU(alpha=0.2)(x)\n", " features.append(x)\n", " \n", " # Downsampling blocks\n", " filter_sizes = [FILTERS*2, FILTERS*4, FILTERS*8, FILTERS*16, FILTERS*32, FILTERS*64]\n", " for filters in filter_sizes:\n", " x, _ = encoder_block(x, filters, use_norm=False)\n", " features.append(x)\n", " \n", " # Multi-scale outputs\n", " outputs = []\n", " for i, feature in enumerate(features[-4:]):\n", " out = layers.Conv2D(1, KERNEL, padding='same')(feature)\n", " outputs.append(out)\n", " \n", " return Model(inputs, outputs, name=name)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*===================== Data Loading =====================*\n", "\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def load_images(path):\n", " images = []\n", " for p in pathlib.Path(path).glob('*.*'):\n", " try:\n", " img = cv2.imread(str(p))\n", " if img is not None:\n", " img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", " 
img = cv2.resize(img, IMAGE_SHAPE[:2])\n", "                img = img.astype(np.float32) / 255.0\n", "                images.append(img)\n", "        except Exception as e:\n", "            print(f\"Error loading image {p}: {e}\")\n", "    return np.array(images)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function is responsible for loading and balancing two different medical imaging datasets: CT scans and MRI scans. The goal is to ensure that both datasets contain the same number of images to avoid class imbalance in training.\n", "\n", "📌 What Does This Function Do?\n", "\n", "Loads CT scans from the given directory.\n", "\n", "Loads MRI scans from the given directory.\n", "\n", "Finds the smaller dataset (CT or MRI) and trims the larger one to match its size.\n", "\n", "Returns balanced datasets with the same number of images.\n", "\n", "1️⃣ Loading CT Scans\n", "\n", " Prints \"Loading CT scans...\" to inform the user.\n", "\n", " Calls load_images(ct_path), the function defined above, which reads the images from the directory specified by ct_path.\n", "\n", " Stores the loaded images in ct_scans.\n", "\n", "2️⃣ Loading MRI Scans\n", "\n", " Prints \"Loading MRI scans...\" to indicate MRI loading.\n", "\n", " Calls load_images(mri_path), which loads images from mri_path.\n", "\n", " Stores the MRI images in mri_scans.\n", "\n", "3️⃣ Balancing the Datasets\n", "\n", " Computes the minimum length of the two datasets.\n", "\n", " Trims the larger dataset so that both end up with the same number of images.\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "\n", "def load_and_balance_datasets(ct_path, mri_path):\n", "    print(\"Loading CT scans...\")\n", "    ct_scans = load_images(ct_path)\n", "    print(\"Loading MRI scans...\")\n", "    mri_scans = load_images(mri_path)\n", "    \n", "    min_length = min(len(ct_scans), len(mri_scans))\n", "    ct_scans = ct_scans[:min_length]\n", "    mri_scans = mri_scans[:min_length]\n", "    \n", "    print(f\"Balanced datasets to {min_length} images each\")\n", "    return ct_scans, mri_scans" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Training Setup - Detailed Explanation*\n", "\n", "This block of code sets up the models and optimizers required for training a CycleGAN for CT ↔ MRI image translation. 
Let’s break it down step by step.\n", "\n", "📌 What Does This Code Do?\n", " Builds the generator models (CT → MRI and MRI → CT).\n", "\n", " Builds the discriminator models for CT and MRI.\n", "\n", " Creates optimizers for training the generators and discriminators.\n", "\n", " Initializes model variables (trainable parameters for both generators and discriminators).\n", "\n", " Builds optimizers using the trainable variables.\n", "\n", "1️⃣ Building the Generator Models\n", "\n", " build_generator(name): This function (explained earlier) builds a U-Net-based Variational Autoencoder (VAE) generator.\n", "\n", " g_ct_mri: The generator that converts CT scans → MRI images.\n", "\n", " g_mri_ct: The generator that converts MRI images → CT scans\n", "\n", "2️⃣ Building the Discriminator Models\n", "\n", " build_discriminator(name): This function (explained earlier) builds the discriminators to differentiate real and fake images.\n", "\n", " d_ct: The discriminator that distinguishes real CT scans from fake ones.\n", "\n", " d_mri: The discriminator that distinguishes real MRI scans from fake ones.\n", "\n", "\n", "3️⃣ Creating Optimizers\n", "\n", " g_opt: Optimizer for training both generators.\n", "\n", " d_opt: Optimizer for training both discriminators.\n", "\n", " Uses RMSprop as the optimizer.\n", "\n", " The learning rate (LEARNING_RATE) controls the step size for updates.\n", "\n", " Weight decay (WEIGHT_DECAY) prevents overfitting by penalizing large weights.\n", "\n", "4️⃣ Initializing Model Variables\n", "\n", " g_vars: Stores all trainable variables (weights & biases) of both generators.\n", " d_vars: Stores all trainable variables of both discriminators.\n", "\n", " 📝 Why store trainable variables separately?\n", "\n", " Since generators and discriminators have separate losses, they need to be updated separately.\n", "\n", "5️⃣ Building Optimizers with Model Variables\n", "\n", " g_opt.build(g_vars): Tells TensorFlow that g_opt will optimize generator variables.\n", "\n", " d_opt.build(d_vars): Tells TensorFlow that d_opt will optimize discriminator variables.\n", "\n", " 📝 Why explicitly build the optimizers?\n", "\n", " In Eager Execution mode, TensorFlow automatically tracks variables.\n", " \n", " However, explicitly calling build() can help with performance optimization.\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From d:\\VS CODE\\Web Dev\\Projects\\Image2Image\\image\\lib\\site-packages\\keras\\src\\backend.py:1398: The name tf.executing_eagerly_outside_functions is deprecated. Please use tf.compat.v1.executing_eagerly_outside_functions instead.\n", "\n", "WARNING:tensorflow:From d:\\VS CODE\\Web Dev\\Projects\\Image2Image\\image\\lib\\site-packages\\keras\\src\\layers\\pooling\\max_pooling2d.py:161: The name tf.nn.max_pool is deprecated. 
Please use tf.nn.max_pool2d instead.\n", "\n" ] }, { "ename": "NameError", "evalue": "name 'Sampling' is not defined", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[12], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# ===================== Training Setup =====================\u001b[39;00m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;66;03m# Build models\u001b[39;00m\n\u001b[1;32m----> 3\u001b[0m g_ct_mri \u001b[38;5;241m=\u001b[39m \u001b[43mbuild_generator\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mCT_to_MRI\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4\u001b[0m g_mri_ct \u001b[38;5;241m=\u001b[39m build_generator(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMRI_to_CT\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 5\u001b[0m d_ct \u001b[38;5;241m=\u001b[39m build_discriminator(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mD_CT\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", "Cell \u001b[1;32mIn[7], line 17\u001b[0m, in \u001b[0;36mbuild_generator\u001b[1;34m(name)\u001b[0m\n\u001b[0;32m 15\u001b[0m z_mean \u001b[38;5;241m=\u001b[39m layers\u001b[38;5;241m.\u001b[39mDense(LATENT_DIM, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mz_mean_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)(x)\n\u001b[0;32m 16\u001b[0m z_log_var \u001b[38;5;241m=\u001b[39m layers\u001b[38;5;241m.\u001b[39mDense(LATENT_DIM, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mz_log_var_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)(x)\n\u001b[1;32m---> 17\u001b[0m z \u001b[38;5;241m=\u001b[39m \u001b[43mSampling\u001b[49m()([z_mean, z_log_var])\n\u001b[0;32m 19\u001b[0m \u001b[38;5;66;03m# Reshape for decoder\u001b[39;00m\n\u001b[0;32m 20\u001b[0m x \u001b[38;5;241m=\u001b[39m layers\u001b[38;5;241m.\u001b[39mDense(\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m FILTERS\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m64\u001b[39m)(z)\n", "\u001b[1;31mNameError\u001b[0m: name 'Sampling' is not defined" ] } ], "source": [ "# ===================== Training Setup =====================\n", "# Build models\n", "g_ct_mri = build_generator('CT_to_MRI')\n", "g_mri_ct = build_generator('MRI_to_CT')\n", "d_ct = build_discriminator('D_CT')\n", "d_mri = build_discriminator('D_MRI')\n", "\n", "# Create optimizers\n", "g_opt = tf.keras.optimizers.RMSprop(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n", "d_opt = tf.keras.optimizers.RMSprop(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n", "\n", "# Initialize model variables\n", "g_vars = g_ct_mri.trainable_variables + g_mri_ct.trainable_variables\n", "d_vars = d_ct.trainable_variables + d_mri.trainable_variables\n", "\n", "# Build optimizers\n", "g_opt.build(g_vars)\n", 
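"# Likewise build the discriminator optimizer's slot variables up front\n", 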
"d_opt.build(d_vars)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Explanation of train_step Function in CycleGAN with Variational Autoencoder (VAE)*\n", "\n", "This function performs one training step for the CycleGAN with VAE-style latent representations. It does the following:\n", "\n", "Generates fake images using the generators.\n", "\n", "Evaluates the fake and real images using the discriminators.\n", "\n", "Computes the loss functions for both generators and discriminators.\n", "\n", "Computes gradients and updates the model parameters.\n", "\n", "1️⃣ Forward Pass - Generate Fake Images\n", "\n", " g_ct_mri(real_ct): Translates CT → Fake MRI and produces:\n", " \n", " fake_mri: The generated MRI image.\n", " z_mean_fwd, z_log_var_fwd: Latent variables (from the Variational Autoencoder).\n", " g_mri_ct(real_mri): Translates MRI → Fake CT with similar outputs.\n", "\n", " 📝 Why store z_mean and z_log_var?\n", "\n", " These come from the VAE latent space and are used for the KL divergence loss.\n", "\n", "2️⃣ Compute Discriminator Outputs\n", "\n", " d_ct(real_ct): Discriminator’s prediction for real CT images.\n", "\n", " d_ct(fake_ct): Discriminator’s prediction for fake CT images.\n", "\n", " d_mri(real_mri): Discriminator’s prediction for real MRI images.\n", "\n", " d_mri(fake_mri): Discriminator’s prediction for fake MRI images.\n", "\n", " 📝 Goal of Discriminators?\n", "\n", "\n", " Real images should be classified close to 1.\n", "\n", " Fake images should be classified close to 0.\n", "\n", "3️⃣ Compute Discriminator Losses\n", "\n", " Uses Least Squares GAN (LSGAN) loss:\n", "\n", " For real images: (real - 1)^2 → Encourages real images to be classified as 1.\n", "\n", " For fake images: fake^2 → Encourages fake images to be classified as 0.\n", "\n", " sum([...]): If there are multiple output layers in the discriminator, we sum their losses.\n", "\n", " 📝 Why LSGAN loss?\n", "\n", " Helps stabilize training compared to standard GAN loss.\n", "\n", "\n", "4️⃣ Cycle Consistency Loss (CycleGAN Component)\n", "\n", " cycled_ct = g_mri_ct(fake_mri): The fake MRI is translated back to CT.\n", "\n", " cycled_mri = g_ct_mri(fake_ct): The fake CT is translated back to MRI.\n", "\n", " 📝 Why cycle consistency?\n", "\n", " The network should learn round-trip consistency:\n", "\n", " CT → Fake MRI → CT (should look like original CT)\n", "\n", " MRI → Fake CT → MRI (should look like original MRI)\n", "\n", "5️⃣ KL Divergence Loss (VAE Component)\n", "\n", " This is the KL divergence loss from VAE:\n", "\n", " Encourages the latent space to follow a Gaussian distribution.\n", "\n", " Prevents mode collapse.\n", "\n", " 📝 Why add KL divergence loss?\n", "\n", " Regularizes the latent space so the generator produces diverse outputs.\n", "\n", "6️⃣ Compute Generator Losses\n", "\n", " The generator wants fake images to be classified as real (1), so we use:\n", "\n", " (fake - 1)^2 → Fake images should be close to 1.\n", "\n", " Cycle consistency loss: L1 loss (|original - reconstructed|).\n", "\n", " Encourages faithful reconstructions.\n", "\n", "\n", " Final generator loss combines:\n", "\n", " Adversarial loss (GAN loss).\n", "\n", " Cycle consistency loss (weighted by 10 for stronger enforcement).\n", "\n", " KL divergence loss (weighted by 0.5).\n", "\n", " \n", "7️⃣ Compute Total Discriminator Loss\n", "\n", "Adds both discriminator losses.\n", "\n", "8️⃣ Compute Gradients & Update Model Parameters\n", "\n", " Computes gradients of discriminator loss 
(d_total_loss).\n", "\n", " Updates discriminator weights (d_vars).\n", "\n", " Computes gradients of generator loss (g_total_loss).\n", "\n", " Updates generator weights (g_vars).\n", "\n", "📝 Why use tf.GradientTape(persistent=True)?\n", "\n", " We need gradients twice (once for discriminators, once for generators).\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "@tf.function\n", "def train_step(real_ct, real_mri):\n", " with tf.GradientTape(persistent=True) as tape:\n", " # Forward passes\n", " fake_mri, z_mean_fwd, z_log_var_fwd = g_ct_mri(real_ct, training=True)\n", " fake_ct, z_mean_bwd, z_log_var_bwd = g_mri_ct(real_mri, training=True)\n", " \n", " # Discriminator outputs\n", " d_real_ct = d_ct(real_ct, training=True)\n", " d_fake_ct = d_ct(fake_ct, training=True)\n", " d_real_mri = d_mri(real_mri, training=True)\n", " d_fake_mri = d_mri(fake_mri, training=True)\n", " \n", " # Discriminator losses\n", " d_ct_loss = sum([tf.reduce_mean((real - 1)**2) + tf.reduce_mean(fake**2) \n", " for real, fake in zip(d_real_ct, d_fake_ct)])\n", " d_mri_loss = sum([tf.reduce_mean((real - 1)**2) + tf.reduce_mean(fake**2) \n", " for real, fake in zip(d_real_mri, d_fake_mri)])\n", " \n", " # Cycle consistency\n", " cycled_ct, _, _ = g_mri_ct(fake_mri, training=True)\n", " cycled_mri, _, _ = g_ct_mri(fake_ct, training=True)\n", " \n", " # KL Divergence\n", " kl_fwd = -0.5 * tf.reduce_mean(1 + z_log_var_fwd - tf.square(z_mean_fwd) - tf.exp(z_log_var_fwd))\n", " kl_bwd = -0.5 * tf.reduce_mean(1 + z_log_var_bwd - tf.square(z_mean_bwd) - tf.exp(z_log_var_bwd))\n", " \n", " # Generator losses\n", " g_adv_loss = sum([tf.reduce_mean((fake - 1)**2) for fake in d_fake_mri + d_fake_ct])\n", " g_cycle_loss = (tf.reduce_mean(tf.abs(real_ct - cycled_ct)) + \n", " tf.reduce_mean(tf.abs(real_mri - cycled_mri)))\n", " g_total_loss = g_adv_loss + 10 * g_cycle_loss + 0.5 * (kl_fwd + kl_bwd)\n", " \n", " # Total discriminator loss\n", " d_total_loss = d_ct_loss + d_mri_loss\n", " \n", " # Update discriminators\n", " d_grads = tape.gradient(d_total_loss, d_vars)\n", " d_opt.apply_gradients(zip(d_grads, d_vars))\n", " \n", " # Update generators\n", " g_grads = tape.gradient(g_total_loss, g_vars)\n", " g_opt.apply_gradients(zip(g_grads, g_vars))\n", " \n", " return {\n", " 'd_ct': d_ct_loss,\n", " 'd_mri': d_mri_loss,\n", " 'g_total': g_total_loss,\n", " 'fake_mri': fake_mri,\n", " 'fake_ct': fake_ct\n", " }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This code defines the main training loop for a CycleGAN-based model that translates between CT and MRI images. It consists of data preparation, training iteration, progress tracking, and model saving. Below is a step-by-step breakdown:\n", "\n", "\n", "1. Create Progress Directory\n", "\n", "The script creates a directory named progress/ inside Kaggle's working directory.\n", "\n", "This directory will store progress images showing how well the model is learning over time.\n", "\n", "2. Load and Balance the Datasets\n", "\n", "Calls load_and_balance_datasets() to load CT and MRI images from the dataset folders.\n", "\n", "Ensures both datasets have the same number of images by truncating the larger set.\n", "\n", "3. 
Create TensorFlow Dataset for Training\n", "\n", "Creates a TensorFlow dataset from the loaded images.\n", "\n", "Shuffles the dataset to introduce randomness and prevent overfitting.\n", "\n", "Batches the dataset to process multiple images in parallel during training.\n", "\n", "\n", "4. Training Loop\n", "\n", "Starts iterating over epochs (EPOCHS defines the total number of passes over the dataset).\n", "\n", "Iterates through mini-batches of CT and MRI scans using train_dataset.\n", "\n", "5. Train the Model (Forward & Backward Pass)\n", "\n", "Calls train_step(ct_batch, mri_batch), which:\n", "\n", " Generates fake MRI from CT (G_CT→MRI) and fake CT from MRI (G_MRI→CT).\n", "\n", " Passes both real and fake images through the discriminators (D_CT and D_MRI).\n", "\n", " Computes adversarial losses, cycle consistency loss, and KL divergence.\n", "\n", " Updates the discriminators (D_CT, D_MRI) and generators (G_CT→MRI, G_MRI→CT).\n", "\n", "Stores the loss values (d_ct_loss, d_mri_loss, g_total_loss) and the generated images.\n", "\n", "6. Print Losses for Monitoring\n", "\n", "Every 10 batches, prints:\n", "\n", "D_CT: Discriminator loss for CT.\n", "\n", "D_MRI: Discriminator loss for MRI.\n", "\n", "G: Total generator loss.\n", "\n", "This helps monitor model performance during training.\n", "\n", "7. Save Sample Images for Progress Tracking\n", "\n", "Every 100 batches, saves progress images to progress/.\n", "\n", "Displays real CT & MRI images alongside their fake counterparts generated by the model.\n", "\n", "Helps visually track improvements in image quality over time.\n", "\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'os' is not defined", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[14], line 4\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# ===================== Main Training Loop =====================\u001b[39;00m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;66;03m# Create progress directory if it doesn't exist\u001b[39;00m\n\u001b[0;32m 3\u001b[0m progress_dir \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/kaggle/working/progress\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m----> 4\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[43mos\u001b[49m\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mexists(progress_dir):\n\u001b[0;32m 5\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(progress_dir)\n\u001b[0;32m 7\u001b[0m \u001b[38;5;66;03m# Load and prepare data\u001b[39;00m\n", "\u001b[1;31mNameError\u001b[0m: name 'os' is not defined" ] } ], "source": [ "\n", "# ===================== Main Training Loop =====================\n", "# Create progress directory if it doesn't exist\n", "progress_dir = '/kaggle/working/progress'\n", "if not os.path.exists(progress_dir):\n", " os.makedirs(progress_dir)\n", "\n", "# Load and prepare data\n", "print(\"Loading datasets...\")\n", "ct_scans, mri_scans = load_and_balance_datasets('/kaggle/input/ct-to-mri-cgan/Dataset/images/trainA', \n", " '/kaggle/input/ct-to-mri-cgan/Dataset/images/trainB')\n", "\n", "# Create TensorFlow dataset\n", "train_dataset = tf.data.Dataset.from_tensor_slices((ct_scans, mri_scans))\n", "train_dataset = train_dataset.shuffle(buffer_size=len(ct_scans)).batch(BATCH_SIZE)\n", "# Training 
loop\n", "print(\"Starting training...\")\n", "for epoch in range(EPOCHS):\n", " for batch_idx, (ct_batch, mri_batch) in enumerate(train_dataset):\n", " results = train_step(ct_batch, mri_batch)\n", " \n", " if batch_idx % 10 == 0:\n", " print(f\"Epoch {epoch}, Batch {batch_idx}: \"\n", " f\"D_CT={float(results['d_ct']):.4f}, \"\n", " f\"D_MRI={float(results['d_mri']):.4f}, \"\n", " f\"G={float(results['g_total']):.4f}\")\n", " \n", " # Save sample images every 100 batches\n", " if batch_idx % 100 == 0:\n", " fig, axes = plt.subplots(2, 2, figsize=(10, 10))\n", " \n", " # Real CT and Fake MRI\n", " axes[0,0].imshow(ct_batch[0].numpy())\n", " axes[0,0].set_title(\"Real CT\")\n", " axes[0,0].axis('off')\n", " \n", " axes[0,1].imshow(results['fake_mri'][0].numpy())\n", " axes[0,1].set_title(\"Fake MRI\")\n", " axes[0,1].axis('off')\n", " \n", " # Real MRI and Fake CT\n", " axes[1,0].imshow(mri_batch[0].numpy())\n", " axes[1,0].set_title(\"Real MRI\")\n", " axes[1,0].axis('off')\n", " \n", " axes[1,1].imshow(results['fake_ct'][0].numpy())\n", " axes[1,1].set_title(\"Fake CT\")\n", " axes[1,1].axis('off')\n", " \n", " plt.tight_layout()\n", " plt.savefig(f'progress/epoch_{epoch}_batch_{batch_idx}.png')\n", " plt.close()\n", " \n", " # Save models after each epoch\n", " save_models(g_ct_mri, g_mri_ct, epoch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "def translate_image(model_path, image_path, output_path, mode='ct_to_mri'):\n", " \"\"\"\n", " Translate a single image using the trained model\n", " \n", " Parameters:\n", " model_path: Path to the saved model\n", " image_path: Path to the input image\n", " output_path: Path to save the translated image\n", " mode: 'ct_to_mri' or 'mri_to_ct'\n", " \"\"\"\n", " # Load model\n", " print(f\"Loading model from {model_path}\")\n", " model = tf.keras.models.load_model(model_path, \n", " custom_objects={'Sampling': Sampling})\n", " \n", " # Load and preprocess image\n", " input_image = load_and_preprocess_image(image_path)\n", " \n", " # Generate translation\n", " print(\"Generating translation...\")\n", " translated_image, _, _ = model(input_image, training=False)\n", " \n", " # Convert to numpy and denormalize\n", " translated_image = translated_image.numpy()[0] * 255\n", " translated_image = translated_image.astype(np.uint8)\n", " \n", " # Save the result\n", " print(f\"Saving translated image to {output_path}\")\n", " plt.figure(figsize=(10, 5))\n", " \n", " plt.subplot(1, 2, 1)\n", " plt.title(\"Input Image\")\n", " plt.imshow(input_image[0])\n", " plt.axis('off')\n", " \n", " plt.subplot(1, 2, 2)\n", " plt.title(\"Translated Image\")\n", " plt.imshow(translated_image)\n", " plt.axis('off')\n", " \n", " plt.tight_layout()\n", " plt.savefig(output_path)\n", " plt.close()\n", " \n", " return translated_image\n", "'''\n", "# Example usage of the translation function\n", "def example_translation():\n", " \"\"\"Example of how to use the translation function\"\"\"\n", " # Paths\n", " ct_to_mri_model = 'saved_models/ct_to_mri_epoch_1000'\n", " mri_to_ct_model = 'saved_models/mri_to_ct_epoch_1000'\n", " \n", " # CT to MRI translation\n", " input_ct = 'path/to/your/ct_image.jpg'\n", " output_mri = 'results/translated_mri.png'\n", " translated_mri = translate_image(ct_to_mri_model, input_ct, output_mri, \n", " mode='ct_to_mri')\n", " \n", " # MRI to CT translation\n", " input_mri = 'path/to/your/mri_image.jpg'\n", " output_ct = 'results/translated_ct.png'\n", " translated_ct = 
translate_image(mri_to_ct_model, input_mri, output_ct, \n", " mode='mri_to_ct')'''" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Complete code in Single Block*" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow.keras import layers, Model\n", "import numpy as np\n", "import cv2\n", "import pathlib\n", "import matplotlib.pyplot as plt\n", "import tensorflow_probability as tfp\n", "\n", "tfd = tfp.distributions\n", "\n", "# ===================== Configuration =====================\n", "IMAGE_SHAPE = (256, 256, 3)\n", "LATENT_DIM = 256\n", "FILTERS = 16\n", "KERNEL = 3\n", "LEARNING_RATE = 0.0001\n", "WEIGHT_DECAY = 6e-8\n", "BATCH_SIZE = 1\n", "EPOCHS = 10\n", "\n", "# ===================== Architecture Components =====================\n", "class Sampling(layers.Layer):\n", " def call(self, inputs):\n", " z_mean, z_log_var = inputs\n", " batch = tf.shape(z_mean)[0]\n", " dim = tf.shape(z_mean)[1]\n", " epsilon = tf.random.normal(shape=(batch, dim))\n", " return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n", "\n", "def residual_block(inputs, filters, use_norm=True):\n", " x = layers.Conv2D(filters, KERNEL, padding='same')(inputs)\n", " x = layers.LeakyReLU(alpha=0.2)(x)\n", " if use_norm:\n", " x = layers.GroupNormalization(groups=1)(x)\n", " x = layers.Conv2D(filters, KERNEL, padding='same')(x)\n", " x = layers.LeakyReLU(alpha=0.2)(x)\n", " if use_norm:\n", " x = layers.GroupNormalization(groups=1)(x)\n", " shortcut = layers.Conv2D(filters, 1, padding='same')(inputs)\n", " return layers.maximum([x, shortcut])\n", "\n", "def encoder_block(inputs, filters, use_norm=True):\n", " x = residual_block(inputs, filters, use_norm)\n", " skip = x\n", " x = layers.MaxPooling2D()(x)\n", " return x, skip\n", "\n", "def decoder_block(inputs, skip, filters, use_norm=True):\n", " x = layers.Conv2DTranspose(filters, KERNEL, strides=2, padding='same')(inputs)\n", " x = layers.maximum([x, skip])\n", " x = residual_block(x, filters, use_norm)\n", " return x\n", "\n", "# ===================== Generator =====================\n", "def build_generator(name):\n", " inputs = layers.Input(IMAGE_SHAPE)\n", " \n", " # Encoder\n", " e1, s1 = encoder_block(inputs, FILTERS)\n", " e2, s2 = encoder_block(e1, FILTERS*2)\n", " e3, s3 = encoder_block(e2, FILTERS*4)\n", " e4, s4 = encoder_block(e3, FILTERS*8)\n", " e5, s5 = encoder_block(e4, FILTERS*16)\n", " e6, s6 = encoder_block(e5, FILTERS*32)\n", " e7, s7 = encoder_block(e6, FILTERS*64)\n", " \n", " # Latent Space\n", " x = layers.Flatten()(e7)\n", " z_mean = layers.Dense(LATENT_DIM, name=f\"z_mean_{name.split('_')[-1]}\")(x)\n", " z_log_var = layers.Dense(LATENT_DIM, name=f\"z_log_var_{name.split('_')[-1]}\")(x)\n", " z = Sampling()([z_mean, z_log_var])\n", " \n", " # Reshape for decoder\n", " x = layers.Dense(2 * 2 * FILTERS*64)(z)\n", " x = layers.Reshape((2, 2, FILTERS*64))(x)\n", " \n", " # Decoder\n", " d0 = decoder_block(x, s7, FILTERS*64)\n", " d1 = decoder_block(d0, s6, FILTERS*32)\n", " d2 = decoder_block(d1, s5, FILTERS*16)\n", " d3 = decoder_block(d2, s4, FILTERS*8)\n", " d4 = decoder_block(d3, s3, FILTERS*4)\n", " d5 = decoder_block(d4, s2, FILTERS*2)\n", " d6 = decoder_block(d5, s1, FILTERS)\n", " \n", " outputs = layers.Conv2D(3, KERNEL, activation='sigmoid', padding='same')(d6)\n", " return Model(inputs, [outputs, z_mean, z_log_var], name=name)\n", "\n", "# ===================== Discriminator =====================\n", "def build_discriminator(name):\n", 
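"    # Multi-scale discriminator: stacked encoder blocks without normalization, then a 1-channel Conv2D head on each of the last four feature maps\n", 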
" inputs = layers.Input(IMAGE_SHAPE)\n", " \n", " # Feature extraction\n", " x = inputs\n", " features = []\n", " \n", " # Initial convolution\n", " x = layers.Conv2D(FILTERS, KERNEL, padding='same')(x)\n", " x = layers.LeakyReLU(alpha=0.2)(x)\n", " features.append(x)\n", " \n", " # Downsampling blocks\n", " filter_sizes = [FILTERS*2, FILTERS*4, FILTERS*8, FILTERS*16, FILTERS*32, FILTERS*64]\n", " for filters in filter_sizes:\n", " x, _ = encoder_block(x, filters, use_norm=False)\n", " features.append(x)\n", " \n", " # Multi-scale outputs\n", " outputs = []\n", " for i, feature in enumerate(features[-4:]):\n", " out = layers.Conv2D(1, KERNEL, padding='same')(feature)\n", " outputs.append(out)\n", " \n", " return Model(inputs, outputs, name=name)\n", "\n", "# ===================== Data Loading =====================\n", "def load_images(path):\n", " images = []\n", " for p in pathlib.Path(path).glob('*.*'):\n", " try:\n", " img = cv2.imread(str(p))\n", " if img is not None:\n", " img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", " img = cv2.resize(img, IMAGE_SHAPE[:2])\n", " img = img.astype(np.float32) / 255.0\n", " images.append(img)\n", " except Exception as e:\n", " print(f\"Error loading image {p}: {e}\")\n", " return np.array(images)\n", "\n", "def load_and_balance_datasets(ct_path, mri_path):\n", " print(\"Loading CT scans...\")\n", " ct_scans = load_images(ct_path)\n", " print(\"Loading MRI scans...\")\n", " mri_scans = load_images(mri_path)\n", " \n", " min_length = min(len(ct_scans), len(mri_scans))\n", " ct_scans = ct_scans[:min_length]\n", " mri_scans = mri_scans[:min_length]\n", " \n", " print(f\"Balanced datasets to {min_length} images each\")\n", " return ct_scans, mri_scans\n", "\n", "# ===================== Training Setup =====================\n", "# Build models\n", "g_ct_mri = build_generator('CT_to_MRI')\n", "g_mri_ct = build_generator('MRI_to_CT')\n", "d_ct = build_discriminator('D_CT')\n", "d_mri = build_discriminator('D_MRI')\n", "\n", "# Create optimizers\n", "g_opt = tf.keras.optimizers.RMSprop(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n", "d_opt = tf.keras.optimizers.RMSprop(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n", "\n", "# Initialize model variables\n", "g_vars = g_ct_mri.trainable_variables + g_mri_ct.trainable_variables\n", "d_vars = d_ct.trainable_variables + d_mri.trainable_variables\n", "\n", "# Build optimizers\n", "g_opt.build(g_vars)\n", "d_opt.build(d_vars)\n", "\n", "# ===================== Training Function =====================\n", "@tf.function\n", "def train_step(real_ct, real_mri):\n", " with tf.GradientTape(persistent=True) as tape:\n", " # Forward passes\n", " fake_mri, z_mean_fwd, z_log_var_fwd = g_ct_mri(real_ct, training=True)\n", " fake_ct, z_mean_bwd, z_log_var_bwd = g_mri_ct(real_mri, training=True)\n", " \n", " # Discriminator outputs\n", " d_real_ct = d_ct(real_ct, training=True)\n", " d_fake_ct = d_ct(fake_ct, training=True)\n", " d_real_mri = d_mri(real_mri, training=True)\n", " d_fake_mri = d_mri(fake_mri, training=True)\n", " \n", " # Discriminator losses\n", " d_ct_loss = sum([tf.reduce_mean((real - 1)**2) + tf.reduce_mean(fake**2) \n", " for real, fake in zip(d_real_ct, d_fake_ct)])\n", " d_mri_loss = sum([tf.reduce_mean((real - 1)**2) + tf.reduce_mean(fake**2) \n", " for real, fake in zip(d_real_mri, d_fake_mri)])\n", " \n", " # Cycle consistency\n", " cycled_ct, _, _ = g_mri_ct(fake_mri, training=True)\n", " cycled_mri, _, _ = g_ct_mri(fake_ct, training=True)\n", " \n", " # KL Divergence\n", " kl_fwd = 
-0.5 * tf.reduce_mean(1 + z_log_var_fwd - tf.square(z_mean_fwd) - tf.exp(z_log_var_fwd))\n", " kl_bwd = -0.5 * tf.reduce_mean(1 + z_log_var_bwd - tf.square(z_mean_bwd) - tf.exp(z_log_var_bwd))\n", " \n",
" # Generator losses\n", " g_adv_loss = sum([tf.reduce_mean((fake - 1)**2) for fake in d_fake_mri + d_fake_ct])\n", " g_cycle_loss = (tf.reduce_mean(tf.abs(real_ct - cycled_ct)) + \n", " tf.reduce_mean(tf.abs(real_mri - cycled_mri)))\n", " g_total_loss = g_adv_loss + 10 * g_cycle_loss + 0.5 * (kl_fwd + kl_bwd)\n", " \n",
" # Total discriminator loss\n", " d_total_loss = d_ct_loss + d_mri_loss\n", " \n",
" # Update discriminators\n", " d_grads = tape.gradient(d_total_loss, d_vars)\n", " d_opt.apply_gradients(zip(d_grads, d_vars))\n", " \n",
" # Update generators\n", " g_grads = tape.gradient(g_total_loss, g_vars)\n", " g_opt.apply_gradients(zip(g_grads, g_vars))\n", " \n",
" return {\n", " 'd_ct': d_ct_loss,\n", " 'd_mri': d_mri_loss,\n", " 'g_total': g_total_loss,\n", " 'fake_mri': fake_mri,\n", " 'fake_ct': fake_ct\n", " }\n", "\n", "\n", "\n",
"import os\n", "\n",
"def save_models(g_ct_mri, g_mri_ct, epoch, model_dir='/kaggle/working/saved_models'):\n", " \"\"\"Save models in HDF5 format after each epoch\"\"\"\n", " if not os.path.exists(model_dir):\n", " os.makedirs(model_dir)\n", " \n", " # Save as .h5 files\n", " ct_path = os.path.join(model_dir, f'ct_to_mri_epoch_{epoch}.h5')\n", " mri_path = os.path.join(model_dir, f'mri_to_ct_epoch_{epoch}.h5')\n", " \n", " g_ct_mri.save(ct_path)\n", " g_mri_ct.save(mri_path)\n", " print(f\"Models saved: {ct_path} and {mri_path}\")\n", "\n", "\n",
"# ===================== Main Training Loop =====================\n", "# Create progress directory if it doesn't exist\n", "progress_dir = '/kaggle/working/progress'\n", "if not os.path.exists(progress_dir):\n", " os.makedirs(progress_dir)\n", "\n",
"# Load and prepare data\n", "print(\"Loading datasets...\")\n", "ct_scans, mri_scans = load_and_balance_datasets('/kaggle/input/ct-to-mri-cgan/Dataset/images/trainA', \n", " '/kaggle/input/ct-to-mri-cgan/Dataset/images/trainB')\n", "\n",
"# Create TensorFlow dataset\n", "train_dataset = tf.data.Dataset.from_tensor_slices((ct_scans, mri_scans))\n", "train_dataset = train_dataset.shuffle(buffer_size=len(ct_scans)).batch(BATCH_SIZE)\n",
"# Training loop\n", "print(\"Starting training...\")\n", "for epoch in range(EPOCHS):\n", " for batch_idx, (ct_batch, mri_batch) in enumerate(train_dataset):\n", " results = train_step(ct_batch, mri_batch)\n", " \n",
" if batch_idx % 10 == 0:\n", " print(f\"Epoch {epoch}, Batch {batch_idx}: \"\n", " f\"D_CT={float(results['d_ct']):.4f}, \"\n", " f\"D_MRI={float(results['d_mri']):.4f}, \"\n", " f\"G={float(results['g_total']):.4f}\")\n", " \n",
" # Save sample images every 100 batches\n", " if batch_idx % 100 == 0:\n", " fig, axes = plt.subplots(2, 2, figsize=(10, 10))\n", " \n",
" # Real CT and Fake MRI\n", " axes[0,0].imshow(ct_batch[0].numpy())\n", " axes[0,0].set_title(\"Real CT\")\n", " axes[0,0].axis('off')\n", " \n",
" axes[0,1].imshow(results['fake_mri'][0].numpy())\n", " axes[0,1].set_title(\"Fake MRI\")\n", " axes[0,1].axis('off')\n", " \n",
" # Real MRI and Fake CT\n", " axes[1,0].imshow(mri_batch[0].numpy())\n", " axes[1,0].set_title(\"Real MRI\")\n", " axes[1,0].axis('off')\n", " \n",
" axes[1,1].imshow(results['fake_ct'][0].numpy())\n", " axes[1,1].set_title(\"Fake CT\")\n", " axes[1,1].axis('off')\n", " \n",
" plt.tight_layout()\n", " plt.savefig(os.path.join(progress_dir, f'epoch_{epoch}_batch_{batch_idx}.png'))\n",
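" # Close the figure after saving so figure handles do not pile up in memory over many batches\n",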
" plt.close()\n", " \n", " # Save models after each epoch\n", " save_models(g_ct_mri, g_mri_ct, epoch)\n", "\n",
"def load_and_preprocess_image(image_path):\n", " \"\"\"Load and preprocess a single image for inference\"\"\"\n", " # Read image\n", " img = cv2.imread(image_path)\n", " if img is None:\n", " raise ValueError(f\"Could not load image from {image_path}\")\n", " \n", " # Convert BGR to RGB\n", " img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", " \n", " # Resize to model's input size\n", " img = cv2.resize(img, (256, 256))\n", " \n", " # Normalize to [0, 1]\n", " img = img.astype(np.float32) / 255.0\n", " \n", " # Add batch dimension\n", " img = np.expand_dims(img, axis=0)\n", " \n", " return img\n", "\n",
"def translate_image(model_path, image_path, output_path, mode='ct_to_mri'):\n", " \"\"\"\n", " Translate a single image using the trained model\n", " \n", " Parameters:\n", " model_path: Path to the saved model\n", " image_path: Path to the input image\n", " output_path: Path to save the translated image\n", " mode: 'ct_to_mri' or 'mri_to_ct' (informational only; the direction is determined by the model loaded from model_path)\n", " \"\"\"\n", " # Load model\n", " print(f\"Loading model from {model_path}\")\n", " model = tf.keras.models.load_model(model_path, \n", " custom_objects={'Sampling': Sampling})\n", " \n", " # Load and preprocess image\n", " input_image = load_and_preprocess_image(image_path)\n", " \n", " # Generate translation\n", " print(\"Generating translation...\")\n", " translated_image, _, _ = model(input_image, training=False)\n", " \n", " # Convert to numpy, denormalize and clip to the valid [0, 255] range\n", " translated_image = np.clip(translated_image.numpy()[0] * 255, 0, 255)\n", " translated_image = translated_image.astype(np.uint8)\n", " \n", " # Save the result\n", " print(f\"Saving translated image to {output_path}\")\n", " plt.figure(figsize=(10, 5))\n", " \n", " plt.subplot(1, 2, 1)\n", " plt.title(\"Input Image\")\n", " plt.imshow(input_image[0])\n", " plt.axis('off')\n", " \n", " plt.subplot(1, 2, 2)\n", " plt.title(\"Translated Image\")\n", " plt.imshow(translated_image)\n", " plt.axis('off')\n", " \n", " plt.tight_layout()\n", " plt.savefig(output_path)\n", " plt.close()\n", " \n", " return translated_image\n",
"'''\n", "# Example usage of the translation function\n", "def example_translation():\n", " \"\"\"Example of how to use the translation function\"\"\"\n", " # Paths (save_models writes the generators as .h5 files)\n", " ct_to_mri_model = 'saved_models/ct_to_mri_epoch_1000.h5'\n", " mri_to_ct_model = 'saved_models/mri_to_ct_epoch_1000.h5'\n", " \n", " # Make sure the output directory exists\n", " os.makedirs('results', exist_ok=True)\n", " \n", " # CT to MRI translation\n", " input_ct = 'path/to/your/ct_image.jpg'\n", " output_mri = 'results/translated_mri.png'\n", " translated_mri = translate_image(ct_to_mri_model, input_ct, output_mri, \n", " mode='ct_to_mri')\n", " \n", " # MRI to CT translation\n", " input_mri = 'path/to/your/mri_image.jpg'\n", " output_ct = 'results/translated_ct.png'\n", " translated_ct = translate_image(mri_to_ct_model, input_mri, output_ct, \n", " mode='mri_to_ct')'''" ] } ], "metadata": { "kernelspec": { "display_name": "image", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" }
}, "nbformat": 4, "nbformat_minor": 2 }