Brain MRI Synthesis with Stable Diffusion (fine-tuned with dataset prompts)
This model is a fine-tuned version of Stable Diffusion v1-4 for prompt-conditioned synthesis of brain MRI FLAIR slices. It combines latent diffusion with dataset-specific prompts to generate realistic 256x256 FLAIR slices whose style and distribution follow a chosen source dataset.
Prompt Conditioning
The model introduces three special prompt tokens, one for each source dataset. During training, each image was paired with a prompt naming the dataset it came from:
- "SHIFTS FLAIR MRI"
- "VH FLAIR MRI"
- "WMH2017 FLAIR MRI"
These prompts were added as special tokens to the tokenizer, and their embeddings were fine-tuned alongside the U-Net, enabling dataset-specific synthesis.
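A minimal sketch of the token-registration step, assuming the Hugging Face `transformers` tokenizer and text-encoder API (`add_tokens`, `resize_token_embeddings`). The helper name `add_dataset_tokens` and the `DATASET_PROMPTS` mapping are hypothetical. The card does not say whether the full prompt strings or only the dataset names were registered, nor how the new embedding rows were initialised, so the helper simply registers whatever strings it is given.

```python
# Hypothetical mapping of dataset name -> training prompt (prompts taken from the card).
DATASET_PROMPTS = {
    "SHIFTS": "SHIFTS FLAIR MRI",
    "VH": "VH FLAIR MRI",
    "WMH2017": "WMH2017 FLAIR MRI",
}


def add_dataset_tokens(tokenizer, text_encoder, tokens):
    """Hypothetical helper: register `tokens` and grow the text encoder's embeddings.

    Intended for transformers' CLIPTokenizer / CLIPTextModel, but duck-typed:
    any tokenizer with add_tokens() and __len__() and any model with
    resize_token_embeddings() will do. Returns how many tokens were new.
    """
    # add_tokens returns the number of strings that were not already in the vocab.
    num_added = tokenizer.add_tokens(list(tokens), special_tokens=True)
    if num_added:
        # Append one embedding row per new token so each gets a trainable vector.
        text_encoder.resize_token_embeddings(len(tokenizer))
    return num_added
```

With the real components, the tokenizer and text encoder would come from the base model, for example `CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")` and `CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")`; the enlarged embedding matrix is then optimised together with the U-Net.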
Training Details
- Base Model: CompVis/stable-diffusion-v1-4
- Architecture: Latent diffusion with a U-Net built from ResNet and attention blocks
- Input resolution (latent): 32x32
- Output resolution (decoded): 256x256 pixels
- Datasets: SHIFTS, VH, and WMH2017 (FLAIR MRI slices)
- Channels: 4 latent channels
- Epochs: 50
- Batch size: 8
- Gradient accumulation: 4 steps
- Optimizer: AdamW
  - Learning Rate: 1.0e-4
  - Betas: (0.95, 0.999)
  - Weight Decay: 1.0e-6
  - Epsilon: 1.0e-8
- LR Scheduler: Cosine schedule with 500 warm-up steps
- Noise Scheduler: DDPM with:
  - num_train_timesteps: 1000
  - beta_start: 0.0001
  - beta_end: 0.02
  - beta_schedule: "linear"
- Mixed Precision: Disabled
- Gradient Clipping: max norm 1.0
- Hardware: NVIDIA A30 GPU with 4 dataloader workers
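The numeric settings above can be cross-checked with a small plain-Python sketch. It assumes a single GPU (the list names one A30), the Stable Diffusion v1 VAE's spatial downsampling factor of 8, and that "cosine schedule with 500 warm-up steps" means linear warm-up followed by a half-cosine decay to zero, as in diffusers' `get_cosine_schedule_with_warmup`. The card does not give the total number of optimizer steps, so the `total_steps` value used at the bottom is hypothetical.

```python
import math

# Values copied from the Training Details list.
BATCH_SIZE = 8
GRAD_ACCUM_STEPS = 4
PEAK_LR = 1.0e-4
WARMUP_STEPS = 500
NUM_TRAIN_TIMESTEPS = 1000
BETA_START, BETA_END = 0.0001, 0.02
IMAGE_SIZE = 256
LATENT_CHANNELS = 4
VAE_DOWNSAMPLE = 8  # assumption: SD v1 VAE shrinks each spatial side by 8


def effective_batch_size(batch_size=BATCH_SIZE, accum=GRAD_ACCUM_STEPS, num_gpus=1):
    """Number of samples contributing to one optimizer step."""
    return batch_size * accum * num_gpus


def latent_shape(image_size=IMAGE_SIZE):
    """(channels, height, width) of the latent the U-Net operates on."""
    side = image_size // VAE_DOWNSAMPLE
    return (LATENT_CHANNELS, side, side)


def cosine_lr_with_warmup(step, total_steps, peak_lr=PEAK_LR, warmup=WARMUP_STEPS):
    """Linear warm-up from 0 to peak_lr, then half-cosine decay to 0 at total_steps."""
    if step < warmup:
        return peak_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


def linear_beta_schedule(T=NUM_TRAIN_TIMESTEPS, start=BETA_START, end=BETA_END):
    """DDPM 'linear' schedule: T betas evenly spaced from start to end."""
    return [start + (end - start) * t / (T - 1) for t in range(T)]


def alphas_cumprod(betas):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s).

    The forward process then noises a clean latent as
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
    """
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


if __name__ == "__main__":
    total_steps = 10_000  # hypothetical: depends on dataset size and the 50 epochs
    print("effective batch size:", effective_batch_size())
    print("latent shape:", latent_shape())
    for s in (0, 250, 500, 5_250, total_steps):
        print(f"step {s}: lr = {cosine_lr_with_warmup(s, total_steps):.2e}")
    ab = alphas_cumprod(linear_beta_schedule())
    print(f"alpha_bar at t=T: {ab[-1]:.2e}")
```

The last value shows why 1000 linear steps up to beta_end = 0.02 are enough: by t = T almost none of the original latent signal remains, so sampling can start from pure Gaussian noise.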
Usage
You can use this model with the diffusers library for prompt-conditioned generation:
```python
from diffusers import DiffusionPipeline
import torch

# Load the model
pipe = DiffusionPipeline.from_pretrained("benetraco/latent_finetuning")
pipe.to("cuda")  # or "cpu"

# Generate a brain MRI image in SHIFTS style
prompt = "SHIFTS FLAIR MRI"
image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=2.0).images[0]
image.show()
```
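To compare the three dataset styles side by side, the same call can be looped over all training prompts. This is a sketch: the helper `generate_per_dataset` and the `DATASET_PROMPTS` mapping are hypothetical names, and the function only relies on the pipeline being callable with `prompt`, `num_inference_steps`, `guidance_scale` and `generator` and returning an object with an `.images` list, as the diffusers text-to-image pipelines do.

```python
# Hypothetical mapping of dataset name -> prompt (prompts taken from the card).
DATASET_PROMPTS = {
    "SHIFTS": "SHIFTS FLAIR MRI",
    "VH": "VH FLAIR MRI",
    "WMH2017": "WMH2017 FLAIR MRI",
}


def generate_per_dataset(pipe, num_inference_steps=50, guidance_scale=2.0, generator=None):
    """Return {dataset_name: image}, one synthetic slice per dataset prompt.

    `pipe` is a loaded DiffusionPipeline (see the example above). Passing a
    seeded `generator` makes the whole sequence of samples reproducible.
    """
    images = {}
    for name, prompt in DATASET_PROMPTS.items():
        result = pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        images[name] = result.images[0]
    return images
```

For example, `images = generate_per_dataset(pipe, generator=torch.Generator("cuda").manual_seed(0))` followed by `images["VH"].convert("L").save("vh_flair.png")`. The `convert("L")` step is optional: the pipeline returns RGB PIL images, and collapsing them to one channel matches the single-channel FLAIR format.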