Brain MRI Synthesis with Stable Diffusion (fine-tuned with dataset prompts)

This model is a fine-tuned version of Stable Diffusion v1-4 for prompt-conditioned synthesis of brain MRI FLAIR slices. It uses latent diffusion and dataset-specific prompts to generate realistic 256x256 FLAIR slices, with the prompt selecting which source dataset's style the output follows.

🔍 Prompt Conditioning

The model introduces three special prompt tokens, one per source dataset. During training, each image was paired with the prompt naming its source:

  • "SHIFTS FLAIR MRI"
  • "VH FLAIR MRI"
  • "WMH2017 FLAIR MRI"

These prompts were added as special tokens to the tokenizer, and their embeddings were fine-tuned alongside the U-Net, enabling dataset-specific synthesis.
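
The training code is not part of this card, so the following is only a toy sketch of the idea in plain Python: each dataset prompt is registered as a single vocabulary entry and receives its own new embedding row, which is what would then be trained together with the U-Net. The vocabulary, embedding size, and the helper names `add_prompt_tokens` and `encode` are hypothetical. With the Hugging Face libraries the equivalent step is usually `tokenizer.add_tokens(...)` followed by `text_encoder.resize_token_embeddings(len(tokenizer))`; whether this model was trained exactly that way is an assumption.

```python
import random

DATASET_PROMPTS = ["SHIFTS FLAIR MRI", "VH FLAIR MRI", "WMH2017 FLAIR MRI"]


def add_prompt_tokens(vocab, embeddings, prompts, dim, seed=0):
    """Register each prompt as one token and append one new embedding row per token."""
    rng = random.Random(seed)
    added = []
    for prompt in prompts:
        if prompt in vocab:
            continue  # already registered, nothing to add
        vocab[prompt] = len(vocab)
        # New rows start as small random vectors; they are the part that gets trained.
        embeddings.append([rng.gauss(0.0, 0.02) for _ in range(dim)])
        added.append(prompt)
    return added


def encode(text, vocab):
    """Whole-prompt special tokens match first; otherwise split on whitespace."""
    if text in vocab:
        return [vocab[text]]
    return [vocab.get(word, vocab["<unk>"]) for word in text.split()]


# Toy vocabulary and embedding table (hypothetical, 4-dimensional).
vocab = {"<unk>": 0, "flair": 1, "mri": 2}
embeddings = [[0.0] * 4 for _ in vocab]
added = add_prompt_tokens(vocab, embeddings, DATASET_PROMPTS, dim=4)

# Each dataset prompt now maps to exactly one id with its own embedding row.
vh_ids = encode("VH FLAIR MRI", vocab)
```

Because each prompt is a single token, the text encoder sees one learned vector per dataset rather than a composition of the words "FLAIR" and "MRI".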

🧠 Training Details

  • Base Model: CompVis/stable-diffusion-v1-4
  • Architecture: Latent diffusion; U-Net with ResNet and attention blocks
  • Input resolution (latent): 32x32
  • Output resolution (decoded): 256x256 pixels
  • Datasets: SHIFTS, VH, and WMH2017 (FLAIR MRI slices)
  • Channels: 4 latent channels
  • Epochs: 50
  • Batch size: 8
  • Gradient accumulation: 4 steps
  • Optimizer: AdamW
    • Learning Rate: 1.0e-4
    • Betas: (0.95, 0.999)
    • Weight Decay: 1.0e-6
    • Epsilon: 1.0e-8
  • LR Scheduler: Cosine schedule with 500 warm-up steps
  • Noise Scheduler: DDPM with:
    • num_train_timesteps: 1000
    • beta_start: 0.0001
    • beta_end: 0.02
    • beta_schedule: "linear"
  • Mixed Precision: Disabled
  • Gradient Clipping: max norm 1.0
  • Hardware: NVIDIA A30 GPU with 4 dataloader workers
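
A few quantities implied by the list above can be worked out in plain Python. This is a minimal sketch rather than the training script: it assumes a single GPU, the 8x spatial downsampling of the Stable Diffusion v1 VAE, and the usual linear warm-up followed by half-cosine decay. The total number of optimizer steps is not stated in this card, so `total_steps` is a hypothetical parameter, and all function names are illustrative.

```python
import math


def effective_batch_size(batch_size=8, grad_accum_steps=4):
    """Samples contributing to each optimizer update (single GPU assumed)."""
    return batch_size * grad_accum_steps


def latent_shape(image_size=256, downsample=8, channels=4):
    """Latent shape (C, H, W) for a square image; 8x is the SD v1 VAE factor."""
    side = image_size // downsample
    return (channels, side, side)


def linear_betas(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02):
    """Linearly spaced noise variances, one per diffusion timestep."""
    step = (beta_end - beta_start) / (num_train_timesteps - 1)
    return [beta_start + i * step for i in range(num_train_timesteps)]


def alphas_cumprod(betas):
    """Cumulative product of (1 - beta_t): squared signal scale left at step t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def cosine_lr(step, total_steps, base_lr=1.0e-4, warmup_steps=500):
    """Linear warm-up to base_lr, then half-cosine decay to zero at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


betas = linear_betas()
alpha_bar = alphas_cumprod(betas)
print(effective_batch_size(), latent_shape(), alpha_bar[-1])
```

The last value of `alpha_bar` is close to zero, meaning that at the final timestep almost none of the original latent remains and the sample is nearly pure noise.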

🧪 Usage

You can use this model via the diffusers library for conditional generation:

from diffusers import DiffusionPipeline
import torch

# Load the model and move it to the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("benetraco/latent_finetuning")
pipe.to(device)

# Generate a brain MRI image in SHIFTS style
prompt = "SHIFTS FLAIR MRI"
image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=2.0).images[0]

image.show()
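
To compare the three dataset styles side by side, one option is to generate one image per prompt from the same seed. The sketch below assumes `pipe` is the pipeline loaded in the snippet above and that it accepts the standard Stable Diffusion call arguments, including `generator`. The helper names `build_prompt` and `generate_all` are hypothetical and not part of this repository.

```python
DATASET_PROMPTS = {
    "shifts": "SHIFTS FLAIR MRI",
    "vh": "VH FLAIR MRI",
    "wmh2017": "WMH2017 FLAIR MRI",
}


def build_prompt(dataset):
    """Return the exact training prompt for a dataset name (case-insensitive)."""
    key = dataset.strip().lower()
    if key not in DATASET_PROMPTS:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {sorted(DATASET_PROMPTS)}"
        )
    return DATASET_PROMPTS[key]


def generate_all(pipe, seed=0, steps=50, guidance_scale=2.0):
    """Generate one image per dataset prompt, reusing the same seed for each."""
    import torch  # imported lazily so build_prompt works without torch installed

    images = {}
    for name in DATASET_PROMPTS:
        # A fresh generator per prompt gives every dataset the same starting noise.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        result = pipe(
            prompt=build_prompt(name),
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        images[name] = result.images[0]
    return images
```

Calling `generate_all(pipe)` returns a dictionary keyed by dataset name, so `images["vh"].save("vh_sample.png")` would write the VH-style sample to disk. Using the exact training strings matters here because the prompts were registered as special tokens; a paraphrased prompt would not hit the fine-tuned embeddings.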