metadata
license: mit
tags:
  - pytorch
  - diffusers
  - stable-diffusion
  - latent-diffusion
  - medical-imaging
  - brain-mri
  - multiple-sclerosis
  - dataset-conditioning

# Brain MRI Synthesis with Stable Diffusion (Fine-Tuned with Dataset Prompts)

This model is a fine-tuned version of Stable Diffusion v1-4 for prompt-conditioned brain MRI synthesis, trained on 2D FLAIR slices from the SHIFTS, VH, and WMH2017 datasets. It uses latent diffusion to generate realistic 256×256 FLAIR scans from 32×32 latent representations, and its dataset-specific prompt tokens give control over the dataset style of the output.
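
The resolutions quoted above imply an 8× spatial reduction per side between pixel space and latent space. The sketch below only restates that arithmetic; the helper names are hypothetical, and the channel count of 4 is taken from the training details of this card.

```python
# Shapes stated in this card: 256x256 images, 32x32 latents with 4 channels.
IMAGE_SIZE = 256
LATENT_SIZE = 32
LATENT_CHANNELS = 4


def downsampling_factor(image_size: int, latent_size: int) -> int:
    """Spatial reduction per side between pixel space and latent space."""
    if image_size % latent_size != 0:
        raise ValueError("image size must be a multiple of latent size")
    return image_size // latent_size


def latent_shape(batch_size: int) -> tuple:
    """Shape of the noise tensor a sampler starts from (NCHW layout)."""
    return (batch_size, LATENT_CHANNELS, LATENT_SIZE, LATENT_SIZE)


print(downsampling_factor(IMAGE_SIZE, LATENT_SIZE))  # 8
print(latent_shape(1))                               # (1, 4, 32, 32)
```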

🔍 Prompt Conditioning

Each training image was paired with a specific dataset prompt:

  • "SHIFTS FLAIR MRI"
  • "VH FLAIR MRI"
  • "WMH2017 FLAIR MRI"

These prompts were added as new tokens in the tokenizer and trained jointly with the model, so generation can be conditioned on a dataset and follow that dataset's distribution.
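
As a minimal sketch of the pairing described above, the snippet below maps each dataset name to its prompt string. The prompt strings come from this card; the helper `prompt_for` and the file names are hypothetical and only illustrate the idea, not the author's data loader.

```python
# Dataset names and prompt strings are taken from the list above.
DATASET_PROMPTS = {
    "SHIFTS": "SHIFTS FLAIR MRI",
    "VH": "VH FLAIR MRI",
    "WMH2017": "WMH2017 FLAIR MRI",
}


def prompt_for(dataset: str) -> str:
    """Return the conditioning prompt for a dataset (hypothetical helper)."""
    try:
        return DATASET_PROMPTS[dataset]
    except KeyError:
        raise ValueError(f"unknown dataset: {dataset!r}") from None


# Hypothetical file names, used only to illustrate the image/prompt pairing.
samples = [("shifts_0001.png", "SHIFTS"), ("wmh_0042.png", "WMH2017")]
pairs = [(path, prompt_for(name)) for path, name in samples]
print(pairs)
```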

🧠 Training Details

  • Base model: CompVis/stable-diffusion-v1-4
  • Architecture: Latent Diffusion (U-Net + ResNet + Attention)
  • Latent resolution: 32x32 (decoded to 256x256)
  • Channels: 4
  • Datasets: SHIFTS, VH, WMH2017 (FLAIR MRI)
  • Epochs: 50
  • Batch size: 8
  • Gradient accumulation: 4
  • Optimizer: AdamW
    • LR: 1.0e-4
    • Betas: (0.95, 0.999)
    • Weight decay: 1.0e-6
    • Epsilon: 1.0e-8
  • LR Scheduler: Cosine decay with 500 warm-up steps
  • Noise Scheduler: DDPM
    • Timesteps: 1000
    • Beta schedule: linear (β_start=0.0001, β_end=0.02)
  • Gradient Clipping: Max norm 1.0
  • Mixed Precision: Disabled
  • Hardware: Single NVIDIA A30 GPU
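
To make the learning-rate settings above concrete, here is a rough sketch of a linear warm-up followed by half-cosine decay, a common shape for "cosine decay with warm-up". The base LR and warm-up length come from this card; `TOTAL_STEPS` is a hypothetical value because the card does not state the total number of optimizer steps, and the exact scheduler implementation used for training is not confirmed here.

```python
import math

BASE_LR = 1.0e-4       # from the card
WARMUP_STEPS = 500     # from the card
TOTAL_STEPS = 10_000   # hypothetical: the card does not give the total step count


def lr_at(step: int, total_steps: int = TOTAL_STEPS) -> float:
    """Linear warm-up to BASE_LR, then half-cosine decay to zero."""
    if step < WARMUP_STEPS:
        return BASE_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS)
    return BASE_LR * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


# Samples per optimizer update, assuming "gradient accumulation: 4" means
# 4 accumulation steps at batch size 8.
effective_batch_size = 8 * 4

for step in (0, 250, 500, 5250, 10_000):
    print(step, lr_at(step))
```

Under these assumptions the rate peaks at 1.0e-4 at step 500 and each optimizer update sees 32 samples.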

✍️ Fine-Tuning Strategy

The text encoder, U-Net, and special prompt embeddings were trained jointly. Images were first encoded by the VAE into a 32×32 latent space, and the diffusion model was trained on those latents.
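
The forward (noising) half of latent diffusion can be sketched in plain Python using the DDPM settings listed in the training details (1000 timesteps, linear betas from 0.0001 to 0.02). This is an illustration that assumes the standard noise-prediction objective, not the author's training code; the random list is only a stand-in for a real VAE latent.

```python
import math
import random

NUM_TIMESTEPS = 1000
BETA_START, BETA_END = 0.0001, 0.02

# Linear beta schedule and the cumulative product of (1 - beta_t).
betas = [
    BETA_START + (BETA_END - BETA_START) * t / (NUM_TIMESTEPS - 1)
    for t in range(NUM_TIMESTEPS)
]
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)


def add_noise(latent, noise, t):
    """Sample x_t given x_0: sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise."""
    a = math.sqrt(alphas_cumprod[t])
    b = math.sqrt(1.0 - alphas_cumprod[t])
    return [a * x + b * n for x, n in zip(latent, noise)]


rng = random.Random(0)
latent = [rng.gauss(0.0, 1.0) for _ in range(4 * 32 * 32)]  # stand-in for a VAE latent
noise = [rng.gauss(0.0, 1.0) for _ in range(len(latent))]
noisy = add_noise(latent, noise, t=500)
```

In such a setup the U-Net would receive `noisy`, the timestep, and the prompt embedding, and would be trained to predict `noise`.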

🧪 Inference (Guided Sampling)

import torch
from diffusers import StableDiffusionPipeline
from torchvision.utils import save_image

# Load the fine-tuned pipeline in full precision on the GPU.
pipe = StableDiffusionPipeline.from_pretrained("benetraco/latent_finetuning", torch_dtype=torch.float32).to("cuda")
# Sample with 999 denoising steps (the training schedule has 1000 timesteps).
pipe.scheduler.set_timesteps(999)

@torch.no_grad()
def get_embeddings(prompt):
    # Pad or truncate to 77 tokens, then encode with the fine-tuned text encoder.
    tokens = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True).to("cuda")
    return pipe.text_encoder(**tokens).last_hidden_state

@torch.no_grad()
def sample(prompt, guidance_scale=2.0, seed=42):
    torch.manual_seed(seed)
    # Start from Gaussian noise in the 4x32x32 latent space.
    latent = torch.randn(1, 4, 32, 32).to("cuda") * pipe.scheduler.init_noise_sigma
    text_emb = get_embeddings(prompt)
    uncond_emb = get_embeddings("")  # empty prompt for the unconditional branch

    for t in pipe.scheduler.timesteps:
        latent_in = pipe.scheduler.scale_model_input(latent, t)
        noise_uncond = pipe.unet(latent_in, t, encoder_hidden_states=uncond_emb).sample
        noise_text = pipe.unet(latent_in, t, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance: push the prediction toward the prompt.
        noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
        latent = pipe.scheduler.step(noise, t, latent).prev_sample

    # Undo the latent scaling, decode with the VAE, and map [-1, 1] to [0, 1].
    latent = latent / pipe.vae.config.scaling_factor
    decoded = pipe.vae.decode(latent).sample
    image = ((decoded + 1.0) / 2.0).clamp(0, 1)
    save_image(image, f"{prompt.replace(' ', '_')}_g{guidance_scale}.png")

sample("SHIFTS FLAIR MRI", guidance_scale=5.0)
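
The guidance step in the sampling loop above is a plain linear extrapolation, which a toy example may make easier to see. The function name `guide` and the three-element lists are illustrative stand-ins, not real U-Net outputs.

```python
def guide(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: move from the unconditional prediction
    toward (and past) the prompt-conditioned one."""
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_text)]


uncond = [0.0, 1.0, -1.0]  # toy values, not real U-Net outputs
text = [1.0, 1.0, 0.0]

print(guide(uncond, text, 0.0))  # same as uncond: the prompt is ignored
print(guide(uncond, text, 1.0))  # same as text: plain conditional prediction
print(guide(uncond, text, 5.0))  # [5.0, 1.0, 4.0]
```

A scale above 1 exaggerates the difference between the two predictions, which is why larger values tend to follow the dataset prompt more strongly.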