# Brain MRI Synthesis with Stable Diffusion (Fine-Tuned with Dataset Prompts)

This model is a fine-tuned version of Stable Diffusion v1-4 for prompt-conditioned brain MRI image synthesis, trained on 2D FLAIR slices from the SHIFTS, VH, and WMH2017 datasets. It uses latent diffusion to generate realistic 256×256 scans from 32×32 latent representations and includes special prompt tokens that allow control over the visual style.

πŸ” Prompt Conditioning

Each training image was paired with a specific dataset prompt:

  • "SHIFTS FLAIR MRI"
  • "VH FLAIR MRI"
  • "WMH2017 FLAIR MRI"

These prompts were added as new tokens in the tokenizer and trained jointly with the model, enabling conditional generation aligned with each dataset's distribution.
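Registering new prompt tokens amounts to appending rows to the text encoder's embedding table (in `transformers`, via `tokenizer.add_tokens` plus `text_encoder.resize_token_embeddings`). The exact training script is not shown here; the sketch below simulates the resizing step with a plain `nn.Embedding`, using the common heuristic of initializing new rows to the mean of the existing embeddings before joint fine-tuning:

```python
import torch
import torch.nn as nn

def add_prompt_tokens(embedding: nn.Embedding, n_new: int) -> nn.Embedding:
    """Extend an embedding table with rows for newly added prompt tokens.

    New rows are initialized to the mean of the existing embeddings,
    a common heuristic before joint fine-tuning.
    """
    old_vocab, dim = embedding.weight.shape
    new_emb = nn.Embedding(old_vocab + n_new, dim)
    with torch.no_grad():
        new_emb.weight[:old_vocab] = embedding.weight
        new_emb.weight[old_vocab:] = embedding.weight.mean(dim=0)
    return new_emb

# Simulated CLIP text-encoder table (the real CLIP vocab is 49408, dim 768).
emb = nn.Embedding(49408, 768)
# One new token per dataset prompt: SHIFTS, VH, WMH2017.
emb = add_prompt_tokens(emb, n_new=3)
print(emb.weight.shape)  # torch.Size([49411, 768])
```

During fine-tuning these new rows receive gradients along with the rest of the text encoder, which is what ties each dataset prompt to its visual style.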

## 🧠 Training Details

  • Base model: CompVis/stable-diffusion-v1-4
  • Architecture: Latent Diffusion (U-Net + ResNet + Attention)
  • Latent resolution: 32×32 (decoded to 256×256)
  • Channels: 4
  • Datasets: SHIFTS, VH, WMH2017 (FLAIR MRI)
  • Epochs: 50
  • Batch size: 8
  • Gradient accumulation: 4
  • Optimizer: AdamW
    • LR: 1.0e-4
    • Betas: (0.95, 0.999)
    • Weight decay: 1.0e-6
    • Epsilon: 1.0e-8
  • LR Scheduler: Cosine decay with 500 warm-up steps
  • Noise Scheduler: DDPM
    • Timesteps: 1000
    • Beta schedule: linear (β_start=0.0001, β_end=0.02)
  • Gradient Clipping: Max norm 1.0
  • Mixed Precision: Disabled
  • Hardware: Single NVIDIA A30 GPU (4 dataloader workers)
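The learning-rate schedule above (cosine decay with 500 warm-up steps, peak LR 1.0e-4) can be sketched as a closed-form function. This mirrors what `diffusers`' `get_cosine_schedule_with_warmup` computes; the `total_steps` value here is illustrative, since the card does not state the steps per epoch:

```python
import math

def lr_at(step: int, base_lr: float = 1.0e-4,
          warmup_steps: int = 500, total_steps: int = 10_000) -> float:
    """Linear warm-up to base_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

print(lr_at(0))       # 0.0
print(lr_at(500))     # 1e-4, the peak after warm-up
print(lr_at(10_000))  # decayed to ~0
```

Note also that with batch size 8 and gradient accumulation 4, the effective batch size per optimizer step is 32.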

## ✍️ Fine-Tuning Strategy

The text encoder, U-Net, and special prompt embeddings were trained jointly. Images were encoded into the 32×32 latent space with a VAE, and the diffusion model was trained on these latents.

## 🧪 Inference (Guided Sampling)

```python
from diffusers import StableDiffusionPipeline
import torch
from torchvision.utils import save_image

# Load the fine-tuned pipeline and use the full DDPM timestep range.
pipe = StableDiffusionPipeline.from_pretrained(
    "benetraco/latent_finetuning", torch_dtype=torch.float32
).to("cuda")
pipe.scheduler.set_timesteps(999)

def get_embeddings(prompt):
    tokens = pipe.tokenizer(
        prompt, return_tensors="pt", padding="max_length",
        max_length=77, truncation=True,
    ).to("cuda")
    return pipe.text_encoder(**tokens).last_hidden_state

def sample(prompt, guidance_scale=2.0, seed=42):
    torch.manual_seed(seed)
    latent = torch.randn(1, 4, 32, 32).to("cuda") * pipe.scheduler.init_noise_sigma
    text_emb = get_embeddings(prompt)
    uncond_emb = get_embeddings("")  # empty prompt for classifier-free guidance

    for t in pipe.scheduler.timesteps:
        latent_in = pipe.scheduler.scale_model_input(latent, t)
        with torch.no_grad():
            noise_uncond = pipe.unet(latent_in, t, encoder_hidden_states=uncond_emb).sample
            noise_text = pipe.unet(latent_in, t, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance: push the prediction toward the text condition.
        noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
        latent = pipe.scheduler.step(noise, t, latent).prev_sample

    # Undo the VAE scaling factor before decoding back to image space.
    latent = latent / pipe.vae.config.scaling_factor
    with torch.no_grad():
        decoded = pipe.vae.decode(latent).sample
    image = ((decoded + 1.0) / 2.0).clamp(0, 1)  # map [-1, 1] to [0, 1]
    save_image(image, f"{prompt.replace(' ', '_')}_g{guidance_scale}.png")

sample("SHIFTS FLAIR MRI", guidance_scale=5.0)
```

Higher guidance scales sharpen adherence to the dataset prompt at the cost of sample diversity.