File size: 7,470 Bytes
5a28eea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
import os
import torch
import json
import numpy as np
from PIL import Image as I
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from datasets import load_dataset
from dataclasses import dataclass
from accelerate import Accelerator
from torchvision import transforms
from skimage.color import rgb2lab, lab2rgb
# from huggingface_hub import HfFolder, Repository, whoami
from diffusers import DDPMPipeline, UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from eval import evaluate
@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for the Lab colorization fine-tuning run.

    Every attribute is annotated so ``@dataclass`` turns it into a real field
    (without annotations the decorator generated an empty dataclass).
    ``TrainingConfig()`` still yields the defaults below, and individual values
    can be overridden by keyword, e.g. ``TrainingConfig(num_epochs=1)``.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 8
    eval_batch_size: int = 8  # how many images to sample during evaluation
    num_epochs: int = 512
    gradient_accumulation_steps: int = 1
    learning_rate: float = 3.3e-5
    lr_warmup_steps: int = 500
    save_image_epochs: int = 16  # run evaluation every N epochs (and on the last one)
    save_model_epochs: int = 16  # save the pipeline every N epochs (and on the last one)
    # `no` for float32, `fp16` for automatic mixed precision.
    # NOTE: train_loop currently keeps this commented out of its Accelerator call.
    mixed_precision: str = "fp16"
    output_dir: str = "m1guelpf_nouns"  # the model name locally and on the HF Hub
    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0
    dataset_output_dir: str = "datasets/"
    dataset_name: str = "m1guelpf/nouns"
    model_url: str = "mrm8488/ddpm-ema-butterflies-128"  # pretrained DDPM pipeline to start from
    model_config: str = "models/model_config.json"  # JSON config for the new UNet2DModel


config = TrainingConfig()
def save_plot(images, path=None):
    """Draw ``images`` side by side in one row and show or save the figure.

    Args:
        images: iterable of anything ``Axes.imshow`` accepts (arrays, PIL images).
            Any number of images is supported; previously more than four
            raised IndexError because the grid was hard-coded to 1x4.
        path: optional output file. When given, the figure is written with
            ``savefig`` and then closed so repeated calls do not accumulate
            open figures. When omitted, ``fig.show()`` is called as before.
    """
    images = list(images)  # allow generators; we need the count up front
    ncols = max(len(images), 1)
    # squeeze=False keeps ``axs`` 2-D even for a single column.
    fig, axs = plt.subplots(1, ncols, figsize=(4 * ncols, 4), squeeze=False)
    for ax, image in zip(axs[0], images):
        ax.imshow(image)
        ax.set_axis_off()
    if path is None:
        fig.show()
    else:
        fig.savefig(path)
        plt.close(fig)
def transform_stc(batch):
tfms = transforms.Compose(
[
transforms.Resize((config.image_size, config.image_size)),
# transforms.ColorJitter(
# brightness=0.3, contrast=0.1, saturation=(1.0, 2.0), hue=0.05
# ),
transforms.ToTensor(),
]
)
rgb_images = [
tfms(I.fromarray(rgb2lab(image.convert("RGB")).astype(np.uint8)))
for image in batch["image"]
]
gray_images = [
tfms(I.fromarray(rgb2lab(image.convert("L").convert("RGB")).astype(np.uint8)))
for image in batch["image"]
]
return {"rgb": rgb_images, "gray": gray_images}
def load_weights(pretrained_model, uninitilized_model):
    """Copy pretrained tensors into ``uninitilized_model`` where name and shape match.

    Tensors whose name is absent from the target (which previously raised
    KeyError) or whose shape differs are skipped. Those entries keep the
    target's fresh initialisation.

    Args:
        pretrained_model: source ``torch.nn.Module``.
        uninitilized_model: destination ``torch.nn.Module``; modified in place.

    Returns:
        ``uninitilized_model`` itself.
    """
    # state_dict() builds a new dict on every call, so fetch the target's once.
    # Its tensors are detached views sharing storage with the module, so
    # copy_() writes straight into the module's parameters and buffers.
    target_state = uninitilized_model.state_dict()
    for name, param in pretrained_model.state_dict().items():
        target = target_state.get(name)
        if target is not None and target.shape == param.shape:
            target.copy_(param)
    return uninitilized_model
def load_pipline(config):
    """Return the pretrained ``DDPMPipeline`` identified by ``config.model_url``.

    The pipeline is loaded with ``DDPMPipeline.from_pretrained``. Callers use
    its ``unet`` as the weight source and its ``scheduler`` for noising.
    """
    return DDPMPipeline.from_pretrained(config.model_url)
def _is_due(epoch, every, num_epochs):
    """Return True when ``epoch + 1`` is a multiple of ``every`` or ``epoch`` is the last one."""
    return (epoch + 1) % every == 0 or epoch == num_epochs - 1


def train_loop(
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    test_dataloader,
    lr_scheduler,
):
    """Train ``model`` to predict the noise added to the a*/b* channels, conditioned on L*.

    For each batch, channel 0 of ``batch["rgb"]`` is passed through clean and
    channels 1-2 are noised with ``noise_scheduler.add_noise``. The model's
    prediction is regressed onto the sampled noise with MSE.

    Every ``config.save_image_epochs`` epochs (and on the last epoch), the first
    test batch's ``"gray"`` tensors are passed to ``evaluate`` (from ``eval``).
    Every ``config.save_model_epochs`` epochs (and on the last epoch), a
    ``DDPMPipeline`` wrapping the model is saved to ``config.output_dir``.
    Losses and learning rates are logged to TensorBoard under
    ``<output_dir>/logs``.
    """
    accelerator = Accelerator(
        # mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    device = accelerator.device
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(
            total=len(train_dataloader), disable=not accelerator.is_local_main_process
        )
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            rgb_images = batch["rgb"]
            rgb_l = rgb_images[:, :1]  # conditioning channel, left un-noised
            rgb_ab = rgb_images[:, 1:]  # chroma channels the model learns to denoise
            # Same shape/dtype/device as rgb_ab, sampled directly on that device.
            noise = torch.randn_like(rgb_ab)
            bs = rgb_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bs,),
                device=device,
            ).long()
            noisy_images = torch.cat(
                [rgb_l, noise_scheduler.add_noise(rgb_ab, noise, timesteps)], dim=1
            )

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                # Only clip once gradients are fully accumulated/synchronised.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "step": global_step,
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1
        progress_bar.close()

        # After each epoch optionally sample demo images with evaluate() and save the model.
        if accelerator.is_main_process:
            run_eval = _is_due(epoch, config.save_image_epochs, config.num_epochs)
            run_save = _is_due(epoch, config.save_model_epochs, config.num_epochs)
            if run_eval or run_save:
                pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(model), scheduler=noise_scheduler
                )
                if run_eval:
                    eval_batch = next(iter(test_dataloader), None)
                    if eval_batch is None:
                        accelerator.print("test_dataloader is empty; skipping evaluation")
                    else:
                        evaluate(eval_batch["gray"].to(device), config, epoch, pipeline)
                if run_save:
                    pipeline.save_pretrained(config.output_dir)

    # Stops/flushes the trackers started by init_trackers().
    accelerator.end_training()
def main():
    """Build the data loaders, model, scheduler and optimizer from ``config`` and train.

    The UNet is created from the JSON at ``config.model_config``. It is then
    seeded with every shape-compatible tensor from the pretrained pipeline's
    UNet at ``config.model_url``. That pipeline's scheduler is reused for
    noising.
    """
    # config.seed was previously unused; seed torch and the split so runs are
    # reproducible and the held-out eval images are identical across runs.
    torch.manual_seed(config.seed)
    dataset = load_dataset(config.dataset_name, split="train")
    dataset = dataset.train_test_split(test_size=0.02, seed=config.seed)
    dataset.set_transform(transform_stc)
    train_dataloader = torch.utils.data.DataLoader(
        dataset["train"], batch_size=config.train_batch_size, shuffle=True
    )
    test_dataloader = torch.utils.data.DataLoader(
        dataset["test"], batch_size=config.eval_batch_size, shuffle=False
    )

    pipeline = load_pipline(config)
    pretrained_model = pipeline.unet
    with open(config.model_config, encoding="utf-8") as rstream:
        model_config = json.load(rstream)
    model = UNet2DModel.from_config(model_config)
    model = load_weights(pretrained_model, model)
    noise_scheduler = pipeline.scheduler

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs),
    )
    train_loop(
        config=config,
        model=model,
        noise_scheduler=noise_scheduler,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        lr_scheduler=lr_scheduler,
    )


if __name__ == "__main__":
    main()
|