# Brain MRI Synthesis with Stable Diffusion (Fine-Tuned with Dataset Prompts)

Fine-tuned version of Stable Diffusion v1-4 for brain MRI synthesis. It uses latent diffusion and dataset-specific prompts to generate realistic 256×256 FLAIR brain scans, with control over the dataset style.
This model is a fine-tuned version of Stable Diffusion v1-4 for prompt-conditioned brain MRI image synthesis, trained on 2D FLAIR slices from the SHIFTS, VH, and WMH2017 datasets. It uses latent diffusion to generate realistic 256×256 scans from latent representations of resolution 32×32 and includes special prompt tokens that allow control over the visual style.
Prompt Conditioning
Each training image was paired with a specific dataset prompt:
- "SHIFTS FLAIR MRI"
- "VH FLAIR MRI"
- "WMH2017 FLAIR MRI"
These prompts were added as new tokens in the tokenizer and their embeddings were trained jointly with the model, so that generation can be conditioned on the prompt to match each dataset's distribution.
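As a rough illustration of what adding prompt tokens involves, the toy sketch below extends a vocabulary and appends one embedding row per new token. It is plain Python and does not use the actual CLIP tokenizer: `vocab`, `embeddings`, `add_tokens`, and the embedding width of 4 are hypothetical. With Hugging Face `transformers`, the analogous calls would typically be `tokenizer.add_tokens(...)` followed by `text_encoder.resize_token_embeddings(len(tokenizer))`; whether the authors used exactly those calls is an assumption.

```python
import random

DATASET_PROMPTS = ["SHIFTS FLAIR MRI", "VH FLAIR MRI", "WMH2017 FLAIR MRI"]


def add_tokens(vocab, embeddings, new_tokens, dim, rng):
    """Register each unseen token and append a freshly initialised embedding row."""
    added = 0
    for token in new_tokens:
        if token in vocab:
            continue  # already known: keep its existing id and embedding
        vocab[token] = len(vocab)
        embeddings.append([rng.gauss(0.0, 0.02) for _ in range(dim)])
        added += 1
    return added


rng = random.Random(0)
dim = 4  # toy embedding width, not the real text-encoder width
vocab = {"<pad>": 0, "brain": 1, "scan": 2}  # toy vocabulary
embeddings = [[0.0] * dim for _ in vocab]

num_added = add_tokens(vocab, embeddings, DATASET_PROMPTS, dim, rng)
print(num_added, vocab["VH FLAIR MRI"], len(embeddings))
```

Only the new rows start from random values. In a real fine-tune they would then be updated by gradient descent together with the rest of the model.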
Training Details
- Base model: CompVis/stable-diffusion-v1-4
- Architecture: Latent Diffusion (U-Net + ResNet + Attention)
- Latent resolution: 32x32 (decoded to 256x256)
- Channels: 4
- Datasets: SHIFTS, VH, WMH2017 (FLAIR MRI)
- Epochs: 50
- Batch size: 8
- Gradient accumulation: 4
- Optimizer: AdamW
- LR: 1.0e-4
- Betas: (0.95, 0.999)
- Weight decay: 1.0e-6
- Epsilon: 1.0e-8
- LR Scheduler: Cosine decay with 500 warm-up steps
- Noise Scheduler: DDPM
- Timesteps: 1000
- Beta schedule: linear (β_start=0.0001, β_end=0.02)
- Gradient Clipping: Max norm 1.0
- Mixed Precision: Disabled
- Hardware: Single NVIDIA A30 GPU
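To make the optimizer settings above more concrete, here is a minimal sketch of a learning-rate curve with 500 linear warm-up steps followed by cosine decay, together with the effective batch size implied by the batch size and gradient accumulation. The function `lr_at` and the value of `total_steps` are hypothetical: the total number of optimizer steps depends on the dataset size, which is not stated here. The half-cosine decay to zero is one common form of this schedule and is assumed rather than confirmed.

```python
import math

BASE_LR = 1.0e-4
WARMUP_STEPS = 500
BATCH_SIZE = 8
GRAD_ACCUM = 4


def lr_at(step, total_steps, base_lr=BASE_LR, warmup_steps=WARMUP_STEPS):
    """Linear warm-up followed by a half-cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


# Samples contributing to each optimizer update on a single GPU.
effective_batch = BATCH_SIZE * GRAD_ACCUM

total_steps = 10_000  # hypothetical value, chosen only for illustration

for step in (0, 250, 500, 5_250, 10_000):
    print(step, f"{lr_at(step, total_steps):.2e}")
```

With these settings each optimizer update averages gradients over 8 × 4 = 32 slices, and the learning rate peaks at 1.0e-4 once warm-up ends.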
Fine-Tuning Strategy
The text encoder, U-Net, and special prompt embeddings were trained jointly. Images were encoded into a 32×32 latent space using the VAE, and the diffusion model was trained on those latents.
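The noising step that latent diffusion training relies on can be sketched in plain Python from the listed scheduler settings (1000 timesteps, linear betas from 0.0001 to 0.02). The helper `add_noise` and the random stand-in latent are hypothetical; in the standard DDPM setup the U-Net is trained to predict the noise `eps` from the noisy latent, and that this fine-tune uses the same objective is an assumption.

```python
import math
import random

NUM_TIMESTEPS = 1000
BETA_START, BETA_END = 1.0e-4, 0.02

# Linear beta schedule and the running product of (1 - beta), often written alpha_bar.
betas = [
    BETA_START + (BETA_END - BETA_START) * i / (NUM_TIMESTEPS - 1)
    for i in range(NUM_TIMESTEPS)
]
alpha_bars = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alpha_bars.append(running)


def add_noise(x0, t, rng):
    """Return (x_t, eps) with x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    signal = math.sqrt(alpha_bars[t])
    sigma = math.sqrt(1.0 - alpha_bars[t])
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    return [signal * x + sigma * e for x, e in zip(x0, eps)], eps


rng = random.Random(0)
# Stand-in for one encoded slice: 4 channels of 32x32 latents, flattened.
latent = [rng.gauss(0.0, 1.0) for _ in range(4 * 32 * 32)]
noisy, eps = add_noise(latent, t=500, rng=rng)
print(len(noisy), alpha_bars[0], alpha_bars[-1])
```

At small `t` the noisy latent is almost identical to the input, while at the last timestep `alpha_bars[-1]` is close to zero, so the latent is nearly pure Gaussian noise. This is why sampling can start from random noise.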
Inference (Guided Sampling)
from diffusers import StableDiffusionPipeline
import torch
from torchvision.utils import save_image

pipe = StableDiffusionPipeline.from_pretrained("benetraco/latent_finetuning", torch_dtype=torch.float32).to("cuda")
pipe.scheduler.set_timesteps(999)

def get_embeddings(prompt):
    # Tokenize to the fixed CLIP context length and return the per-token hidden states.
    tokens = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=77).to("cuda")
    with torch.no_grad():
        return pipe.text_encoder(**tokens).last_hidden_state

def sample(prompt, guidance_scale=2.0, seed=42):
    torch.manual_seed(seed)
    # Start from Gaussian noise in the 4x32x32 latent space.
    latent = torch.randn(1, 4, 32, 32).to("cuda") * pipe.scheduler.init_noise_sigma
    text_emb = get_embeddings(prompt)
    uncond_emb = get_embeddings("")
    for t in pipe.scheduler.timesteps:
        latent_in = pipe.scheduler.scale_model_input(latent, t)
        with torch.no_grad():
            noise_uncond = pipe.unet(latent_in, t, encoder_hidden_states=uncond_emb).sample
            noise_text = pipe.unet(latent_in, t, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance: push the prediction toward the prompt-conditioned one.
        noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
        latent = pipe.scheduler.step(noise, t, latent).prev_sample
    # Undo the latent scaling applied at encoding time, then decode with the VAE.
    latent = latent / pipe.vae.config.scaling_factor
    with torch.no_grad():
        decoded = pipe.vae.decode(latent).sample
    # Map from [-1, 1] to [0, 1] before saving.
    image = (decoded + 1.0) / 2.0
    image = image.clamp(0, 1)
    save_image(image, f"{prompt.replace(' ', '_')}_g{guidance_scale}.png")

sample("SHIFTS FLAIR MRI", guidance_scale=5.0)
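The guidance step in the sampling loop combines the unconditional and prompt-conditioned noise predictions. The toy sketch below applies the same formula to short Python lists, so the effect of `guidance_scale` can be checked by hand. The helper `guided_noise` and the example vectors are hypothetical and stand in for the U-Net outputs.

```python
def guided_noise(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: move from the unconditional prediction toward the conditioned one."""
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_text)]


noise_uncond = [0.0, 1.0, -1.0]  # toy stand-in for the prediction with an empty prompt
noise_text = [1.0, 1.0, 0.0]     # toy stand-in for the prediction with a dataset prompt

print(guided_noise(noise_uncond, noise_text, 0.0))  # equals noise_uncond
print(guided_noise(noise_uncond, noise_text, 1.0))  # equals noise_text
print(guided_noise(noise_uncond, noise_text, 5.0))  # extrapolates past noise_text
```

A scale of 1.0 reproduces the conditioned prediction, and larger values extrapolate further away from the unconditional one. This generally strengthens adherence to the dataset prompt, usually at some cost in sample diversity.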