---
license: mit
tags:
- pytorch
- diffusers
- stable-diffusion
- latent-diffusion
- medical-imaging
- brain-mri
- multiple-sclerosis
- dataset-conditioning
---

# Brain MRI Synthesis with Stable Diffusion (Fine-Tuned with Dataset Prompts)

This model is a fine-tuned version of Stable Diffusion v1-4 for prompt-conditioned brain MRI synthesis, trained on 2D FLAIR slices from the SHIFTS, VH, and WMH2017 datasets. It uses latent diffusion to generate realistic 256×256 scans from 32×32 latent representations, and includes special prompt tokens that give control over the dataset-specific visual style.
## 🔍 Prompt Conditioning

Each training image was paired with a dataset-specific prompt:
- "SHIFTS FLAIR MRI"
- "VH FLAIR MRI"
- "WMH2017 FLAIR MRI"
These prompts were added as new tokens in the tokenizer and trained jointly with the model, enabling conditional generation aligned with each dataset's distribution.
## 🧠 Training Details

- Base model: CompVis/stable-diffusion-v1-4
- Architecture: Latent Diffusion (U-Net + ResNet + Attention)
- Latent resolution: 32x32 (decoded to 256x256)
- Channels: 4
- Datasets: SHIFTS, VH, WMH2017 (FLAIR MRI)
- Epochs: 50
- Batch size: 8
- Gradient accumulation: 4
- Optimizer: AdamW
- LR: 1.0e-4
- Betas: (0.95, 0.999)
- Weight decay: 1.0e-6
- Epsilon: 1.0e-8
- LR Scheduler: Cosine decay with 500 warm-up steps
- Noise Scheduler: DDPM
- Timesteps: 1000
- Beta schedule: linear (β_start=0.0001, β_end=0.02)
- Gradient Clipping: Max norm 1.0
- Mixed Precision: Disabled
- Hardware: Single NVIDIA A30 GPU
## ✍️ Fine-Tuning Strategy

The text encoder, U-Net, and special prompt embeddings were trained jointly. Images were encoded into a 32×32 latent space with the VAE, and the diffusion objective was applied in that latent space.
## 🧪 Inference (Guided Sampling)

```python
from diffusers import StableDiffusionPipeline
import torch
from torchvision.utils import save_image

pipe = StableDiffusionPipeline.from_pretrained(
    "benetraco/latent_finetuning", torch_dtype=torch.float32
).to("cuda")
pipe.scheduler.set_timesteps(999)

def get_embeddings(prompt):
    # Tokenize and encode the prompt with the fine-tuned text encoder.
    tokens = pipe.tokenizer(
        prompt, return_tensors="pt", padding="max_length",
        max_length=77, truncation=True,
    ).to("cuda")
    return pipe.text_encoder(**tokens).last_hidden_state

def sample(prompt, guidance_scale=2.0, seed=42):
    torch.manual_seed(seed)
    # Start from pure Gaussian noise in the 4x32x32 latent space.
    latent = torch.randn(1, 4, 32, 32).to("cuda") * pipe.scheduler.init_noise_sigma
    text_emb = get_embeddings(prompt)
    uncond_emb = get_embeddings("")  # empty prompt for classifier-free guidance

    for t in pipe.scheduler.timesteps:
        latent_in = pipe.scheduler.scale_model_input(latent, t)
        with torch.no_grad():
            noise_uncond = pipe.unet(latent_in, t, encoder_hidden_states=uncond_emb).sample
            noise_text = pipe.unet(latent_in, t, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance: push the prediction toward the text condition.
        noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
        latent = pipe.scheduler.step(noise, t, latent).prev_sample

    # Undo the VAE scaling factor before decoding back to image space.
    latent = latent / pipe.vae.config.scaling_factor
    with torch.no_grad():
        decoded = pipe.vae.decode(latent).sample
    # Map from [-1, 1] to [0, 1] and save.
    image = ((decoded + 1.0) / 2.0).clamp(0, 1)
    save_image(image, f"{prompt.replace(' ', '_')}_g{guidance_scale}.png")

sample("SHIFTS FLAIR MRI", guidance_scale=5.0)
```