"""Fine-tune a DDPM UNet to predict the colour (ab) channels of Lab images.

The model is conditioned on the lightness (L) channel: each training input is
the clean L channel concatenated with noised ab channels, and the network is
trained to predict the noise that was added to the ab channels.
"""

import json
import os
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import load_dataset
# from huggingface_hub import HfFolder, Repository, whoami
from diffusers import DDPMPipeline, UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from PIL import Image as I
from skimage.color import rgb2lab, lab2rgb
from torchvision import transforms
from tqdm.auto import tqdm

from eval import evaluate


@dataclass
class TrainingConfig:
    """Hyper-parameters and paths used by the training script.

    The attributes are annotated so that ``@dataclass`` treats them as real
    fields (keyword overrides, a useful ``repr``); ``TrainingConfig()`` with
    no arguments yields the same values as before.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 8
    eval_batch_size: int = 8  # how many images to sample during evaluation
    num_epochs: int = 512
    gradient_accumulation_steps: int = 1
    learning_rate: float = 3.3e-5
    lr_warmup_steps: int = 500
    save_image_epochs: int = 16
    save_model_epochs: int = 16
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "m1guelpf_nouns"  # the model name locally and on the HF Hub
    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0  # used to make the train/test split reproducible
    dataset_output_dir: str = "datasets/"
    dataset_name: str = "m1guelpf/nouns"  # passed to `load_dataset`
    model_url: str = "mrm8488/ddpm-ema-butterflies-128"  # pretrained pipeline to start from
    model_config: str = "models/model_config.json"  # JSON config for the UNet that is trained


config = TrainingConfig()


def save_plot(images):
    """Show ``images`` side by side in a single-row matplotlib figure.

    The grid always has at least four columns (the original fixed layout) and
    grows when more images are supplied, instead of raising ``IndexError``.
    """
    images = list(images)
    n_cols = max(len(images), 4)
    fig, axs = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4), squeeze=False)
    for ax, image in zip(axs[0], images):
        ax.imshow(image)
        ax.set_axis_off()
    fig.show()


def _to_lab_tensor(image, tfms):
    """Convert a PIL image to Lab, store it as uint8 and apply ``tfms``."""
    # NOTE(review): rgb2lab returns floats and the a/b channels can be
    # negative, so the uint8 cast looks lossy. It is kept as-is because the
    # evaluation code presumably decodes the same encoding -- confirm before
    # changing.
    lab = rgb2lab(image).astype(np.uint8)
    return tfms(I.fromarray(lab))


def transform_stc(batch):
    """Dataset transform producing Lab-encoded tensors for a batch.

    Args:
        batch: mapping with an ``"image"`` entry holding PIL images.

    Returns:
        A dict with ``"rgb"`` (Lab encoding of the colour image) and
        ``"gray"`` (Lab encoding of the grayscale version of the image), each
        a list of tensors resized to ``config.image_size``.
    """
    tfms = transforms.Compose(
        [
            transforms.Resize((config.image_size, config.image_size)),
            # transforms.ColorJitter(
            #     brightness=0.3, contrast=0.1, saturation=(1.0, 2.0), hue=0.05
            # ),
            transforms.ToTensor(),
        ]
    )
    rgb_images = [
        _to_lab_tensor(image.convert("RGB"), tfms) for image in batch["image"]
    ]
    gray_images = [
        _to_lab_tensor(image.convert("L").convert("RGB"), tfms)
        for image in batch["image"]
    ]
    return {"rgb": rgb_images, "gray": gray_images}


def load_weights(pretrained_model, uninitilized_model):
    """Copy every shape-compatible tensor from one model into another.

    Entries of ``pretrained_model`` that do not exist in
    ``uninitilized_model`` or whose shape differs are skipped, so the target
    keeps its fresh initialisation for those tensors.

    Returns:
        ``uninitilized_model``, updated in place.
    """
    # state_dict() tensors share storage with the module's parameters, so an
    # in-place copy_ updates the model. Build the dict once, not per entry.
    target_state = uninitilized_model.state_dict()
    with torch.no_grad():
        for name, param in pretrained_model.state_dict().items():
            target = target_state.get(name)
            if target is not None and target.shape == param.shape:
                target.copy_(param)
    return uninitilized_model


def load_pipline(config):
    """Load the pretrained ``DDPMPipeline`` named by ``config.model_url``."""
    pipeline = DDPMPipeline.from_pretrained(config.model_url)
    return pipeline


def train_loop(
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    test_dataloader,
    lr_scheduler,
):
    """Train ``model`` to denoise the ab channels given the L channel.

    For every batch the ``"rgb"`` tensor is split into its first channel and
    the remaining channels; noise is added to the latter at a random timestep
    and the model, fed the concatenation, is trained with an MSE loss against
    that noise. On the main process, every ``config.save_image_epochs`` epochs
    (and on the last one) ``evaluate`` is called with the first ``"gray"``
    batch of ``test_dataloader``; every ``config.save_model_epochs`` epochs
    (and on the last one) the pipeline is saved to ``config.output_dir``.
    """
    accelerator = Accelerator(
        # mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    device = accelerator.device

    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0
    for epoch in range(config.num_epochs):
        # The context manager closes the bar at the end of each epoch.
        with tqdm(
            total=len(train_dataloader),
            disable=not accelerator.is_local_main_process,
        ) as progress_bar:
            progress_bar.set_description(f"Epoch {epoch}")

            for step, batch in enumerate(train_dataloader):
                rgb_images = batch["rgb"]
                rgb_l = rgb_images[:, :1]
                rgb_ab = rgb_images[:, 1:]

                # Sample the noise for the ab channels directly on the
                # batch's device instead of on the CPU.
                noise = torch.randn_like(rgb_ab)
                bs = rgb_images.shape[0]

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bs,),
                    device=device,
                ).long()

                # Only the ab channels are noised; L is the clean condition.
                noisy_images = torch.cat(
                    [rgb_l, noise_scheduler.add_noise(rgb_ab, noise, timesteps)],
                    dim=1,
                )

                with accelerator.accumulate(model):
                    # Predict the noise residual
                    noise_pred = model(
                        noisy_images,
                        timesteps,
                        return_dict=False,
                    )[0]
                    loss = F.mse_loss(noise_pred, noise)
                    accelerator.backward(loss)

                    # Clip only on steps where gradients are synchronised,
                    # i.e. when the optimizer will actually step.
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                progress_bar.update(1)
                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                    "step": global_step,
                }
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                global_step += 1

        # After each epoch you optionally sample some demo images with
        # evaluate() and save the model
        if accelerator.is_main_process:
            is_last_epoch = epoch == config.num_epochs - 1
            sample_now = (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch
            save_now = (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch

            if sample_now or save_now:
                pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(model), scheduler=noise_scheduler
                )

                if sample_now:
                    # Only the first test batch is used; skip quietly when
                    # the test loader is empty.
                    eval_batch = next(iter(test_dataloader), None)
                    if eval_batch is not None:
                        eval_images = eval_batch["gray"].to(device)
                        evaluate(eval_images, config, epoch, pipeline)

                if save_now:
                    pipeline.save_pretrained(config.output_dir)

    # Let the trackers initialised above flush and close.
    accelerator.end_training()


def main():
    """Build the data loaders, model, optimizer and scheduler, then train."""
    dataset = load_dataset(config.dataset_name, split="train")
    # Seeding the split keeps the held-out set stable across runs.
    dataset = dataset.train_test_split(test_size=0.02, seed=config.seed)
    dataset.set_transform(transform_stc)

    train_dataloader = torch.utils.data.DataLoader(
        dataset["train"], batch_size=config.train_batch_size, shuffle=True
    )
    test_dataloader = torch.utils.data.DataLoader(
        dataset["test"], batch_size=config.eval_batch_size, shuffle=False
    )

    pipeline = load_pipline(config)
    pretrained_model = pipeline.unet

    with open(config.model_config, encoding="utf-8") as rstream:
        model_config = json.load(rstream)

    # Start from the JSON-configured UNet and reuse every pretrained tensor
    # whose shape matches.
    model = UNet2DModel.from_config(model_config)
    model = load_weights(pretrained_model, model)

    noise_scheduler = pipeline.scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs),
    )

    train_loop(
        config=config,
        model=model,
        noise_scheduler=noise_scheduler,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        lr_scheduler=lr_scheduler,
    )
    # notebook_launcher(train_loop, args, num_processes=1)


if __name__ == "__main__":
    main()