benetraco commited on
Commit
e64a26c
verified
1 Parent(s): 9d5e04d

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +70 -0
README.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - pytorch
5
+ - diffusers
6
+ - stable-diffusion
7
+ - latent-diffusion
8
+ - medical-imaging
9
+ - brain-mri
10
+ - multiple-sclerosis
11
+ - dataset-conditioning
12
+ ---
13
+
14
+ # Brain MRI Synthesis with Stable Diffusion (fine-tuned with dataset prompts)
15
+
16
+ This model is a **fine-tuned version of Stable Diffusion v1-4** for **prompt-conditioned synthesis of brain MRI FLAIR slices**. It leverages **latent diffusion** and dataset-specific prompts to generate realistic 256x256 FLAIR scans with control over the source dataset's style or distribution.
17
+
18
+ ## 馃攳 Prompt Conditioning
19
+
20
+ The model introduces three special prompt tokens corresponding to the dataset of origin. During training, each image was paired with a prompt indicating its source:
21
+
22
+ - `"SHIFTS FLAIR MRI"`
23
+ - `"VH FLAIR MRI"`
24
+ - `"WMH2017 FLAIR MRI"`
25
+
26
+ These prompts were added as special tokens to the tokenizer, and their embeddings were fine-tuned alongside the U-Net, enabling dataset-specific synthesis.
27
+
28
+ ## 馃 Training Details
29
+
30
+ - **Base Model:** [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)
31
+ - **Architecture:** Latent Diffusion with U-Net + ResNet + Attention
32
+ - **Input resolution (latent):** 32x32
33
+ - **Output resolution (decoded):** 256x256 pixels
34
+ - **Datasets:** SHIFTS, VH, and WMH2017 (FLAIR MRI slices)
35
+ - **Channels:** 4 latent channels
36
+ - **Epochs:** 50
37
+ - **Batch size:** 8
38
+ - **Gradient accumulation:** 4 steps
39
+ - **Optimizer:** AdamW
40
+ - Learning Rate: `1.0e-4`
41
+ - Betas: (0.95, 0.999)
42
+ - Weight Decay: `1.0e-6`
43
+ - Epsilon: `1.0e-8`
44
+ - **LR Scheduler:** Cosine schedule with 500 warm-up steps
45
+ - **Noise Scheduler:** DDPM with:
46
+ - `num_train_timesteps`: 1000
47
+ - `beta_start`: 0.0001
48
+ - `beta_end`: 0.02
49
+ - `beta_schedule`: "linear"
50
+ - **Mixed Precision:** Disabled
51
+ - **Gradient Clipping:** max norm 1.0
52
+ - **Hardware:** NVIDIA A30 GPU with 4 dataloader workers
53
+
54
+ ## 馃И Usage
55
+
56
+ You can use this model via the `diffusers` library for conditional generation:
57
+
58
+ ```python
59
+ from diffusers import DiffusionPipeline
60
+ import torch
61
+
62
+ # Load the model
63
+ pipe = DiffusionPipeline.from_pretrained("benetraco/latent_finetuning")
64
+ pipe.to("cuda") # or "cpu"
65
+
66
+ # Generate a brain MRI image in SHIFTS style
67
+ prompt = "SHIFTS FLAIR MRI"
68
+ image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=2.0).images[0]
69
+
70
+ image.show()