|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import datetime |
|
import getpass |
|
import hashlib |
|
import json |
|
import os |
|
import os.path as osp |
|
import random |
|
import time |
|
import types |
|
import warnings |
|
from dataclasses import asdict |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import pyrallis |
|
import torch |
|
from accelerate import Accelerator, InitProcessGroupKwargs |
|
from accelerate.utils import DistributedType |
|
from PIL import Image |
|
from termcolor import colored |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
from diffusion import DPMS, FlowEuler, Scheduler |
|
from diffusion.data.builder import build_dataloader, build_dataset |
|
from diffusion.data.wids import DistributedRangedSampler |
|
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode |
|
from diffusion.model.respace import compute_density_for_timestep_sampling |
|
from diffusion.utils.checkpoint import load_checkpoint, save_checkpoint |
|
from diffusion.utils.config import SanaConfig |
|
from diffusion.utils.data_sampler import AspectRatioBatchSampler |
|
from diffusion.utils.dist_utils import clip_grad_norm_, flush, get_world_size |
|
from diffusion.utils.logger import LogBuffer, get_root_logger |
|
from diffusion.utils.lr_scheduler import build_lr_scheduler |
|
from diffusion.utils.misc import DebugUnderflowOverflow, init_random_seed, read_config, set_random_seed |
|
from diffusion.utils.optimizer import auto_scale_lr, build_optimizer |
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
def set_fsdp_env():
    """Configure Accelerate's FSDP mode through environment variables.

    Enables FSDP with transformer-based auto-wrapping around ``SanaBlock``
    modules and pre-backward parameter prefetching. Must be called before the
    ``Accelerator`` is constructed for the settings to take effect.
    """
    fsdp_settings = {
        "ACCELERATE_USE_FSDP": "true",
        "FSDP_AUTO_WRAP_POLICY": "TRANSFORMER_BASED_WRAP",
        "FSDP_BACKWARD_PREFETCH": "BACKWARD_PRE",
        "FSDP_TRANSFORMER_CLS_TO_WRAP": "SanaBlock",
    }
    os.environ.update(fsdp_settings)
|
|
|
|
|
@torch.inference_mode()
def log_validation(accelerator, config, model, logger, step, device, vae=None, init_noise=None):
    """Sample validation images from the current model and log them to trackers.

    For each prompt in the module-global ``validation_prompts``, loads the
    pre-computed prompt embedding from disk, runs the sampler selected by
    ``config.scheduler.vis_sampler`` ("dpm-solver", "flow_euler" or
    "flow_dpm-solver"), decodes the latents with the VAE and collects PIL
    images. If ``init_noise`` is given, a second pass re-samples every prompt
    from that fixed noise (suffix " w/ init noise") so visuals are comparable
    across training steps. Images are pushed to every active tracker
    (tensorboard / wandb) and, when ``config.train.local_save_vis`` is set,
    saved locally as a single concatenated grid.

    Relies on module-level globals prepared in ``main()``: ``image_size``,
    ``latent_size``, ``null_embed_path``, ``validation_prompts`` and
    ``valid_prompt_embed_suffix``.

    Args:
        accelerator: Accelerate handle (model unwrapping, trackers, device).
        config: Full training config.
        model: Possibly-wrapped diffusion model; unwrapped and set to eval here.
        logger: Root logger.
        step: Global training step, used as the tracker tag / filename suffix.
        device: Device to sample on.
        vae: Optional pre-built VAE; lazily built from config when None.
        init_noise: Optional fixed latent noise for the deterministic pass.

    Returns:
        list[dict]: one entry per sampled prompt with keys
        ``validation_prompt`` and ``images`` (list of PIL images).
    """
    torch.cuda.empty_cache()
    vis_sampler = config.scheduler.vis_sampler
    model = accelerator.unwrap_model(model).eval()
    # Single square image: height/width and aspect-ratio conditioning tensors.
    hw = torch.tensor([[image_size, image_size]], dtype=torch.float, device=device).repeat(1, 1)
    ar = torch.tensor([[1.0]], device=device).repeat(1, 1)
    # Unconditional (null) prompt embedding for classifier-free guidance.
    null_y = torch.load(null_embed_path, map_location="cpu")
    null_y = null_y["uncond_prompt_embeds"].to(device)

    logger.info("Running validation... ")
    image_logs = []

    def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"):
        # Sample one image per validation prompt; returns a list of log dicts.
        latents = []
        current_image_logs = []
        for prompt in validation_prompts:
            # Fresh noise per prompt unless a fixed init latent was provided.
            z = (
                torch.randn(1, config.vae.vae_latent_dim, latent_size, latent_size, device=device)
                if init_z is None
                else init_z
            )
            # Prompt embeddings are precomputed in main() and keyed by the
            # first 50 characters of the prompt.
            embed = torch.load(
                osp.join(config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}"),
                map_location="cpu",
            )
            caption_embs, emb_masks = embed["caption_embeds"].to(device), embed["emb_mask"].to(device)

            model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

            if sampler == "dpm-solver":
                dpm_solver = DPMS(
                    model.forward_with_dpmsolver,
                    condition=caption_embs,
                    uncondition=null_y,
                    cfg_scale=4.5,
                    model_kwargs=model_kwargs,
                )
                denoised = dpm_solver.sample(
                    z,
                    steps=14,
                    order=2,
                    skip_type="time_uniform",
                    method="multistep",
                )
            elif sampler == "flow_euler":
                flow_solver = FlowEuler(
                    model, condition=caption_embs, uncondition=null_y, cfg_scale=4.5, model_kwargs=model_kwargs
                )
                denoised = flow_solver.sample(z, steps=28)
            elif sampler == "flow_dpm-solver":
                dpm_solver = DPMS(
                    model.forward_with_dpmsolver,
                    condition=caption_embs,
                    uncondition=null_y,
                    cfg_scale=4.5,
                    model_type="flow",
                    model_kwargs=model_kwargs,
                    schedule="FLOW",
                )
                denoised = dpm_solver.sample(
                    z,
                    steps=20,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=config.scheduler.flow_shift,
                )
            else:
                raise ValueError(f"{sampler} not implemented")

            latents.append(denoised)
        torch.cuda.empty_cache()
        # Lazily build the VAE only if the caller did not hand one in.
        if vae is None:
            vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, accelerator.device).to(torch.float16)
        for prompt, latent in zip(validation_prompts, latents):
            latent = latent.to(torch.float16)
            samples = vae_decode(config.vae.vae_type, vae, latent)
            # Map [-1, 1] decoder output to uint8 HWC for PIL.
            samples = (
                torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0]
            )
            image = Image.fromarray(samples)
            current_image_logs.append({"validation_prompt": prompt + label_suffix, "images": [image]})

        return current_image_logs

    image_logs += run_sampling(init_z=None, label_suffix="", vae=vae, sampler=vis_sampler)

    if init_noise is not None:
        init_noise = torch.clone(init_noise).to(device)
        image_logs += run_sampling(init_z=init_noise, label_suffix=" w/ init noise", vae=vae, sampler=vis_sampler)

    # Flatten to (caption, ndarray) pairs for the tracker back-ends.
    formatted_images = []
    for log in image_logs:
        images = log["images"]
        validation_prompt = log["validation_prompt"]
        for image in images:
            formatted_images.append((validation_prompt, np.asarray(image)))

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            for validation_prompt, image in formatted_images:
                tracker.writer.add_images(validation_prompt, image[None, ...], step, dataformats="NHWC")
        elif tracker.name == "wandb":
            import wandb

            wandb_images = []
            for validation_prompt, image in formatted_images:
                wandb_images.append(wandb.Image(image, caption=validation_prompt, file_type="jpg"))
            tracker.log({"validation": wandb_images})
        else:
            # Fix: Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(f"image logging not implemented for {tracker.name}")

    def concatenate_images(image_caption, images_per_row=5, image_format="webp"):
        # Tile the logged images into one grid; round-trip through an
        # in-memory webp buffer so the returned PIL image is compressed.
        import io

        images = [log["images"][0] for log in image_caption]
        if images[0].size[0] > 1024:
            images = [image.resize((1024, 1024)) for image in images]

        widths, heights = zip(*(img.size for img in images))
        max_width = max(widths)
        # NOTE(review): row height is taken from the FIRST image of each row;
        # assumes uniform image sizes within a row — confirm for multi-scale.
        total_height = sum(heights[i : i + images_per_row][0] for i in range(0, len(images), images_per_row))

        new_im = Image.new("RGB", (max_width * images_per_row, total_height))

        y_offset = 0
        for i in range(0, len(images), images_per_row):
            row_images = images[i : i + images_per_row]
            x_offset = 0
            for img in row_images:
                new_im.paste(img, (x_offset, y_offset))
                x_offset += max_width
            y_offset += heights[i]
        webp_image_bytes = io.BytesIO()
        new_im.save(webp_image_bytes, format=image_format)
        webp_image_bytes.seek(0)
        new_im = Image.open(webp_image_bytes)

        return new_im

    if config.train.local_save_vis:
        file_format = "webp"
        local_vis_save_path = osp.join(config.work_dir, "log_vis")
        os.umask(0o000)  # file permission: 666; dir permission: 777
        os.makedirs(local_vis_save_path, exist_ok=True)
        concatenated_image = concatenate_images(image_logs, images_per_row=5, image_format=file_format)
        save_path = (
            osp.join(local_vis_save_path, f"vis_{step}.{file_format}")
            if init_noise is None
            else osp.join(local_vis_save_path, f"vis_{step}_w_init.{file_format}")
        )
        concatenated_image.save(save_path)

    del vae
    flush()
    return image_logs
|
|
|
|
|
def train(config, args, accelerator, model, optimizer, lr_scheduler, train_dataloader, train_diffusion, logger):
    """Run the main diffusion training loop (or a cache-only dataset pass).

    Two modes:
      * ``args.caching`` with multi-scale enabled: iterate the dataloader once
        purely to populate the aspect-ratio batch-sampler cache file on disk,
        then return without training.
      * normal: train for ``config.train.num_epochs`` epochs with periodic
        logging, checkpointing, visual validation, and hard time limits
        (~3.8 h) after which the loop checkpoints and returns.

    Relies on module-level globals prepared in ``main()``: ``start_step``,
    ``start_epoch``, ``train_dataloader_len``, ``total_steps``, ``cache_file``,
    ``rank``, ``num_replicas``, ``vae``, ``tokenizer``, ``text_encoder``,
    ``load_vae_feat``, ``load_text_feat``, ``max_length``, ``generator``,
    ``validation_noise`` and ``training_start_time``.

    Batch layout assumption (from the accesses below): ``batch[0]`` image or
    VAE latent, ``batch[1]`` caption text or text embedding, ``batch[2]``
    embedding mask, ``batch[3]`` data_info dict, ``batch[5]`` raw caption
    string — TODO confirm against the dataset builder.
    """
    if getattr(config.train, "debug_nan", False):
        DebugUnderflowOverflow(model)
        logger.info("NaN debugger registered. Start to detect overflow during training.")
    log_buffer = LogBuffer()

    # Resume bookkeeping: skip already-consumed steps within the epoch, but
    # only when the remainder leaves at least ~20 fresh steps.
    global_step = start_step + 1
    skip_step = max(config.train.skip_step, global_step) % train_dataloader_len
    skip_step = skip_step if skip_step < (train_dataloader_len - 20) else 0
    loss_nan_timer = 0

    if args.caching and config.model.multi_scale:
        # Cache-only pass: walk the dataloader so the batch sampler records
        # per-sample bucket indices, then dump them to cache_file and exit.
        caching_start = time.time()
        logger.info(
            f"Start caching your dataset for batch_sampler at {cache_file}. \n"
            f"This may take a lot of time...No training will launch"
        )
        train_dataloader.batch_sampler.sampler.set_start(max(train_dataloader.batch_sampler.exist_ids, 0))
        accelerator.wait_for_everyone()
        for index, _ in enumerate(train_dataloader):
            accelerator.wait_for_everyone()
            if index % 2000 == 0:
                logger.info(
                    f"rank: {rank}, Cached file len: {len(train_dataloader.batch_sampler.cached_idx)} / {len(train_dataloader)}"
                )
                print(
                    f"rank: {rank}, Cached file len: {len(train_dataloader.batch_sampler.cached_idx)} / {len(train_dataloader)}"
                )
            # Hard wall-clock limit (~3.7 h): flush whatever is cached and stop.
            if (time.time() - caching_start) / 3600 > 3.7:
                json.dump(train_dataloader.batch_sampler.cached_idx, open(cache_file, "w"), indent=4)
                accelerator.wait_for_everyone()
                break
            # Near-complete cache: persist an intermediate snapshot.
            if len(train_dataloader.batch_sampler.cached_idx) == len(train_dataloader) - 1000:
                logger.info(
                    f"Saving rank: {rank}, Cached file len: {len(train_dataloader.batch_sampler.cached_idx)} / {len(train_dataloader)}"
                )
                json.dump(train_dataloader.batch_sampler.cached_idx, open(cache_file, "w"), indent=4)
                accelerator.wait_for_everyone()
            continue
        accelerator.wait_for_everyone()
        # Final dump of this rank's cache before returning (no training).
        print(f"Saving rank-{rank} Cached file len: {len(train_dataloader.batch_sampler.cached_idx)}")
        json.dump(train_dataloader.batch_sampler.cached_idx, open(cache_file, "w"), indent=4)
        return

    for epoch in range(start_epoch + 1, config.train.num_epochs + 1):
        time_start, last_tic = time.time(), time.time()
        # Multi-scale / distributed runs drive the inner sampler of the batch
        # sampler; the plain case uses the dataloader's own sampler.
        sampler = (
            train_dataloader.batch_sampler.sampler
            if (num_replicas > 1 or config.model.multi_scale)
            else train_dataloader.sampler
        )
        sampler.set_epoch(epoch)
        # Fast-forward past already-trained samples when resuming mid-epoch.
        sampler.set_start(max((skip_step - 1) * config.train.train_batch_size, 0))
        if skip_step > 1 and accelerator.is_main_process:
            logger.info(f"Skipped Steps: {skip_step}")
        skip_step = 1  # only skip within the first resumed epoch
        data_time_start = time.time()
        data_time_all = 0
        lm_time_all = 0
        vae_time_all = 0
        model_time_all = 0
        for step, batch in enumerate(train_dataloader):

            accelerator.wait_for_everyone()
            data_time_all += time.time() - data_time_start
            vae_time_start = time.time()
            # --- VAE stage: either precomputed latents or on-the-fly encode.
            if load_vae_feat:
                z = batch[0].to(accelerator.device)
            else:
                with torch.no_grad():
                    with torch.amp.autocast(
                        "cuda",
                        enabled=(config.model.mixed_precision == "fp16" or config.model.mixed_precision == "bf16"),
                    ):
                        z = vae_encode(
                            config.vae.vae_type, vae, batch[0], config.vae.sample_posterior, accelerator.device
                        )

            accelerator.wait_for_everyone()
            vae_time_all += time.time() - vae_time_start

            clean_images = z
            data_info = batch[3]

            lm_time_start = time.time()
            # --- Text stage: precomputed embeddings, or encode captions now.
            if load_text_feat:
                y = batch[1]
                y_mask = batch[2]
            else:
                if "T5" in config.text_encoder.text_encoder_name:
                    with torch.no_grad():
                        txt_tokens = tokenizer(
                            batch[1], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
                        ).to(accelerator.device)
                        y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None]
                        y_mask = txt_tokens.attention_mask[:, None, None]
                elif (
                    "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name
                ):
                    with torch.no_grad():
                        if not config.text_encoder.chi_prompt:
                            max_length_all = config.text_encoder.model_max_length
                            prompt = batch[1]
                        else:
                            # Complex Human Instruction: prepend the CHI prompt
                            # to every caption and widen the token budget
                            # accordingly (-2 for special tokens, presumably —
                            # confirm against tokenizer).
                            chi_prompt = "\n".join(config.text_encoder.chi_prompt)
                            prompt = [chi_prompt + i for i in batch[1]]
                            num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt))
                            max_length_all = (
                                num_chi_prompt_tokens + config.text_encoder.model_max_length - 2
                            )
                        txt_tokens = tokenizer(
                            prompt,
                            padding="max_length",
                            max_length=max_length_all,
                            truncation=True,
                            return_tensors="pt",
                        ).to(accelerator.device)
                        # Keep the BOS token plus the last model_max_length-1
                        # positions, i.e. drop the CHI-prompt tokens.
                        select_index = [0] + list(
                            range(-config.text_encoder.model_max_length + 1, 0)
                        )
                        y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None][
                            :, :, select_index
                        ]
                        y_mask = txt_tokens.attention_mask[:, None, None][:, :, :, select_index]
                else:
                    print("error")
                    exit()

            # --- Timestep sampling: uniform by default, logit-normal optional.
            bs = clean_images.shape[0]
            timesteps = torch.randint(
                0, config.scheduler.train_sampling_steps, (bs,), device=clean_images.device
            ).long()
            if config.scheduler.weighting_scheme in ["logit_normal"]:

                u = compute_density_for_timestep_sampling(
                    weighting_scheme=config.scheduler.weighting_scheme,
                    batch_size=bs,
                    logit_mean=config.scheduler.logit_mean,
                    logit_std=config.scheduler.logit_std,
                    mode_scale=None,
                )
                timesteps = (u * config.scheduler.train_sampling_steps).long().to(clean_images.device)
            grad_norm = None
            accelerator.wait_for_everyone()
            lm_time_all += time.time() - lm_time_start
            model_time_start = time.time()
            # --- Optimization step (gradient accumulation handled by Accelerate).
            with accelerator.accumulate(model):

                optimizer.zero_grad()
                loss_term = train_diffusion.training_losses(
                    model, clean_images, timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info)
                )
                loss = loss_term["loss"].mean()
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.train.gradient_clip)
                optimizer.step()
                lr_scheduler.step()
            accelerator.wait_for_everyone()
            model_time_all += time.time() - model_time_start

            # Count (don't immediately abort on) NaN losses; abort after 20.
            if torch.any(torch.isnan(loss)):
                loss_nan_timer += 1
            lr = lr_scheduler.get_last_lr()[0]
            logs = {args.loss_report_name: accelerator.gather(loss).mean().item()}
            if grad_norm is not None:
                logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
            log_buffer.update(logs)
            # --- Periodic console/log-file reporting (and on the first step).
            if (step + 1) % config.train.log_interval == 0 or (step + 1) == 1:
                accelerator.wait_for_everyone()
                t = (time.time() - last_tic) / config.train.log_interval
                t_d = data_time_all / config.train.log_interval
                t_m = model_time_all / config.train.log_interval
                t_lm = lm_time_all / config.train.log_interval
                t_vae = vae_time_all / config.train.log_interval
                avg_time = (time.time() - time_start) / (step + 1)
                eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - global_step - 1))))
                eta_epoch = str(
                    datetime.timedelta(
                        seconds=int(
                            avg_time
                            * (train_dataloader_len - sampler.step_start // config.train.train_batch_size - step - 1)
                        )
                    )
                )
                log_buffer.average()

                current_step = (
                    global_step - sampler.step_start // config.train.train_batch_size
                ) % train_dataloader_len
                current_step = train_dataloader_len if current_step == 0 else current_step
                info = (
                    f"Epoch: {epoch} | Global Step: {global_step} | Local Step: {current_step} // {train_dataloader_len}, "
                    f"total_eta: {eta}, epoch_eta:{eta_epoch}, time: all:{t:.3f}, model:{t_m:.3f}, data:{t_d:.3f}, "
                    f"lm:{t_lm:.3f}, vae:{t_vae:.3f}, lr:{lr:.3e}, Cap: {batch[5][0]}, "
                )
                # Current sample resolution; model may be DDP/FSDP-wrapped.
                info += (
                    f"s:({model.module.h}, {model.module.w}), "
                    if hasattr(model, "module")
                    else f"s:({model.h}, {model.w}), "
                )

                info += ", ".join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
                last_tic = time.time()
                log_buffer.clear()
                data_time_all = 0
                model_time_all = 0
                lm_time_all = 0
                vae_time_all = 0
                if accelerator.is_main_process:
                    logger.info(info)

            logs.update(lr=lr)
            if accelerator.is_main_process:
                accelerator.log(logs, step=global_step)

            global_step += 1

            if loss_nan_timer > 20:
                raise ValueError("Loss is NaN too much times. Break here.")
            # --- Checkpointing: on schedule, or when the ~3.8 h budget is hit.
            if global_step % config.train.save_model_steps == 0 or (time.time() - training_start_time) / 3600 > 3.8:
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    os.umask(0o000)
                    ckpt_saved_path = save_checkpoint(
                        osp.join(config.work_dir, "checkpoints"),
                        epoch=epoch,
                        step=global_step,
                        model=accelerator.unwrap_model(model),
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        generator=generator,
                        add_symlink=True,
                    )
                    # Drop a marker file an external metric monitor can pick up.
                    if config.train.online_metric and global_step % config.train.eval_metric_step == 0 and step > 1:
                        online_metric_monitor_dir = osp.join(config.work_dir, config.train.online_metric_dir)
                        os.makedirs(online_metric_monitor_dir, exist_ok=True)
                        with open(f"{online_metric_monitor_dir}/{ckpt_saved_path.split('/')[-1]}.txt", "w") as f:
                            f.write(osp.join(config.work_dir, "config.py") + "\n")
                            f.write(ckpt_saved_path)

            # Time budget exceeded: checkpoint above already ran, now exit.
            if (time.time() - training_start_time) / 3600 > 3.8:
                logger.info(f"Stopping training at epoch {epoch}, step {global_step} due to time limit.")
                return
            # --- Periodic visual validation on the main process.
            if config.train.visualize and (global_step % config.train.eval_sampling_steps == 0 or (step + 1) == 1):
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    if validation_noise is not None:
                        log_validation(
                            accelerator=accelerator,
                            config=config,
                            model=model,
                            logger=logger,
                            step=global_step,
                            device=accelerator.device,
                            vae=vae,
                            init_noise=validation_noise,
                        )
                    else:
                        log_validation(
                            accelerator=accelerator,
                            config=config,
                            model=model,
                            logger=logger,
                            step=global_step,
                            device=accelerator.device,
                            vae=vae,
                        )

            # Multi-scale: the bucketed sampler can't fill the last few
            # batches evenly; end the epoch early and realign global_step.
            if (
                config.model.multi_scale
                and (train_dataloader_len - sampler.step_start // config.train.train_batch_size - step) < 30
            ):
                global_step = epoch * train_dataloader_len
                logger.info("Early stop current iteration")
                break

            data_time_start = time.time()

        # NOTE(review): `and` binds tighter than `or`, so this reads as
        # `(epoch % save_model_epochs == 0) or (epoch == num_epochs and not debug)`
        # — a debug run still saves on every save_model_epochs boundary.
        # Confirm that is intended.
        if epoch % config.train.save_model_epochs == 0 or epoch == config.train.num_epochs and not config.debug:
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:

                ckpt_saved_path = save_checkpoint(
                    osp.join(config.work_dir, "checkpoints"),
                    epoch=epoch,
                    step=global_step,
                    model=accelerator.unwrap_model(model),
                    optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                    generator=generator,
                    add_symlink=True,
                )

                online_metric_monitor_dir = osp.join(config.work_dir, config.train.online_metric_dir)
                os.makedirs(online_metric_monitor_dir, exist_ok=True)
                with open(f"{online_metric_monitor_dir}/{ckpt_saved_path.split('/')[-1]}.txt", "w") as f:
                    f.write(osp.join(config.work_dir, "config.py") + "\n")
                    f.write(ckpt_saved_path)
        accelerator.wait_for_everyone()
|
|
|
|
|
@pyrallis.wrap() |
|
def main(cfg: SanaConfig) -> None: |
|
global train_dataloader_len, start_epoch, start_step, vae, generator, num_replicas, rank, training_start_time |
|
global load_vae_feat, load_text_feat, validation_noise, text_encoder, tokenizer |
|
global max_length, validation_prompts, latent_size, valid_prompt_embed_suffix, null_embed_path |
|
global image_size, cache_file, total_steps |
|
|
|
config = cfg |
|
args = cfg |
|
|
|
|
|
training_start_time = time.time() |
|
load_from = True |
|
if args.resume_from or config.model.resume_from: |
|
load_from = False |
|
config.model.resume_from = dict( |
|
checkpoint=args.resume_from or config.model.resume_from, |
|
load_ema=False, |
|
resume_optimizer=True, |
|
resume_lr_scheduler=True, |
|
) |
|
|
|
if args.debug: |
|
config.train.log_interval = 1 |
|
config.train.train_batch_size = min(64, config.train.train_batch_size) |
|
args.report_to = "tensorboard" |
|
|
|
os.umask(0o000) |
|
os.makedirs(config.work_dir, exist_ok=True) |
|
|
|
init_handler = InitProcessGroupKwargs() |
|
init_handler.timeout = datetime.timedelta(seconds=5400) |
|
|
|
if config.train.use_fsdp: |
|
init_train = "FSDP" |
|
from accelerate import FullyShardedDataParallelPlugin |
|
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig |
|
|
|
set_fsdp_env() |
|
fsdp_plugin = FullyShardedDataParallelPlugin( |
|
state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False), |
|
) |
|
else: |
|
init_train = "DDP" |
|
fsdp_plugin = None |
|
|
|
accelerator = Accelerator( |
|
mixed_precision=config.model.mixed_precision, |
|
gradient_accumulation_steps=config.train.gradient_accumulation_steps, |
|
log_with=args.report_to, |
|
project_dir=osp.join(config.work_dir, "logs"), |
|
fsdp_plugin=fsdp_plugin, |
|
kwargs_handlers=[init_handler], |
|
) |
|
|
|
log_name = "train_log.log" |
|
logger = get_root_logger(osp.join(config.work_dir, log_name)) |
|
logger.info(accelerator.state) |
|
|
|
config.train.seed = init_random_seed(getattr(config.train, "seed", None)) |
|
set_random_seed(config.train.seed + int(os.environ["LOCAL_RANK"])) |
|
generator = torch.Generator(device="cpu").manual_seed(config.train.seed) |
|
|
|
if accelerator.is_main_process: |
|
pyrallis.dump(config, open(osp.join(config.work_dir, "config.yaml"), "w"), sort_keys=False, indent=4) |
|
if args.report_to == "wandb": |
|
import wandb |
|
|
|
wandb.init(project=args.tracker_project_name, name=args.name, resume="allow", id=args.name) |
|
|
|
logger.info(f"Config: \n{config}") |
|
logger.info(f"World_size: {get_world_size()}, seed: {config.train.seed}") |
|
logger.info(f"Initializing: {init_train} for training") |
|
image_size = config.model.image_size |
|
latent_size = int(image_size) // config.vae.vae_downsample_rate |
|
pred_sigma = getattr(config.scheduler, "pred_sigma", True) |
|
learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma |
|
max_length = config.text_encoder.model_max_length |
|
vae = None |
|
validation_noise = ( |
|
torch.randn(1, config.vae.vae_latent_dim, latent_size, latent_size, device="cpu", generator=generator) |
|
if getattr(config.train, "deterministic_validation", False) |
|
else None |
|
) |
|
if not config.data.load_vae_feat: |
|
vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, accelerator.device).to(torch.float16) |
|
tokenizer = text_encoder = None |
|
if not config.data.load_text_feat: |
|
tokenizer, text_encoder = get_tokenizer_and_text_encoder( |
|
name=config.text_encoder.text_encoder_name, device=accelerator.device |
|
) |
|
text_embed_dim = text_encoder.config.hidden_size |
|
else: |
|
text_embed_dim = config.text_encoder.caption_channels |
|
|
|
logger.info(f"vae type: {config.vae.vae_type}") |
|
if config.text_encoder.chi_prompt: |
|
chi_prompt = "\n".join(config.text_encoder.chi_prompt) |
|
logger.info(f"Complex Human Instruct: {chi_prompt}") |
|
|
|
os.makedirs(config.train.null_embed_root, exist_ok=True) |
|
null_embed_path = osp.join( |
|
config.train.null_embed_root, |
|
f"null_embed_diffusers_{config.text_encoder.text_encoder_name}_{max_length}token_{text_embed_dim}.pth", |
|
) |
|
if config.train.visualize and len(config.train.validation_prompts): |
|
|
|
valid_prompt_embed_suffix = f"{max_length}token_{config.text_encoder.text_encoder_name}_{text_embed_dim}.pth" |
|
validation_prompts = config.train.validation_prompts |
|
skip = True |
|
if config.text_encoder.chi_prompt: |
|
uuid_chi_prompt = hashlib.sha256(chi_prompt.encode()).hexdigest() |
|
else: |
|
uuid_chi_prompt = hashlib.sha256(b"").hexdigest() |
|
config.train.valid_prompt_embed_root = osp.join(config.train.valid_prompt_embed_root, uuid_chi_prompt) |
|
Path(config.train.valid_prompt_embed_root).mkdir(parents=True, exist_ok=True) |
|
|
|
if config.text_encoder.chi_prompt: |
|
|
|
chi_prompt_file = osp.join(config.train.valid_prompt_embed_root, "chi_prompt.txt") |
|
with open(chi_prompt_file, "w", encoding="utf-8") as f: |
|
f.write(chi_prompt) |
|
|
|
for prompt in validation_prompts: |
|
prompt_embed_path = osp.join( |
|
config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}" |
|
) |
|
if not (osp.exists(prompt_embed_path) and osp.exists(null_embed_path)): |
|
skip = False |
|
logger.info("Preparing Visualization prompt embeddings...") |
|
break |
|
if accelerator.is_main_process and not skip: |
|
if config.data.load_text_feat and (tokenizer is None or text_encoder is None): |
|
logger.info(f"Loading text encoder and tokenizer from {config.text_encoder.text_encoder_name} ...") |
|
tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name) |
|
|
|
for prompt in validation_prompts: |
|
prompt_embed_path = osp.join( |
|
config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}" |
|
) |
|
if "T5" in config.text_encoder.text_encoder_name: |
|
txt_tokens = tokenizer( |
|
prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" |
|
).to(accelerator.device) |
|
caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0] |
|
caption_emb_mask = txt_tokens.attention_mask |
|
elif ( |
|
"gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name |
|
): |
|
if not config.text_encoder.chi_prompt: |
|
max_length_all = config.text_encoder.model_max_length |
|
else: |
|
chi_prompt = "\n".join(config.text_encoder.chi_prompt) |
|
prompt = chi_prompt + prompt |
|
num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt)) |
|
max_length_all = ( |
|
num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 |
|
) |
|
|
|
txt_tokens = tokenizer( |
|
prompt, |
|
max_length=max_length_all, |
|
padding="max_length", |
|
truncation=True, |
|
return_tensors="pt", |
|
).to(accelerator.device) |
|
select_index = [0] + list(range(-config.text_encoder.model_max_length + 1, 0)) |
|
caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][ |
|
:, select_index |
|
] |
|
caption_emb_mask = txt_tokens.attention_mask[:, select_index] |
|
else: |
|
raise ValueError(f"{config.text_encoder.text_encoder_name} is not supported!!") |
|
|
|
torch.save({"caption_embeds": caption_emb, "emb_mask": caption_emb_mask}, prompt_embed_path) |
|
|
|
null_tokens = tokenizer( |
|
"", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" |
|
).to(accelerator.device) |
|
if "T5" in config.text_encoder.text_encoder_name: |
|
null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=null_tokens.attention_mask)[0] |
|
elif "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name: |
|
null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=null_tokens.attention_mask)[0] |
|
else: |
|
raise ValueError(f"{config.text_encoder.text_encoder_name} is not supported!!") |
|
torch.save( |
|
{"uncond_prompt_embeds": null_token_emb, "uncond_prompt_embeds_mask": null_tokens.attention_mask}, |
|
null_embed_path, |
|
) |
|
if config.data.load_text_feat: |
|
del tokenizer |
|
del text_encoder |
|
del null_token_emb |
|
del null_tokens |
|
flush() |
|
|
|
os.environ["AUTOCAST_LINEAR_ATTN"] = "true" if config.model.autocast_linear_attn else "false" |
|
|
|
|
|
train_diffusion = Scheduler( |
|
str(config.scheduler.train_sampling_steps), |
|
noise_schedule=config.scheduler.noise_schedule, |
|
predict_v=config.scheduler.predict_v, |
|
learn_sigma=learn_sigma, |
|
pred_sigma=pred_sigma, |
|
snr=config.train.snr_loss, |
|
flow_shift=config.scheduler.flow_shift, |
|
) |
|
predict_info = f"v-prediction: {config.scheduler.predict_v}, noise schedule: {config.scheduler.noise_schedule}" |
|
if "flow" in config.scheduler.noise_schedule: |
|
predict_info += f", flow shift: {config.scheduler.flow_shift}" |
|
if config.scheduler.weighting_scheme in ["logit_normal", "mode"]: |
|
predict_info += ( |
|
f", flow weighting: {config.scheduler.weighting_scheme}, " |
|
f"logit-mean: {config.scheduler.logit_mean}, logit-std: {config.scheduler.logit_std}" |
|
) |
|
logger.info(predict_info) |
|
|
|
|
|
model_kwargs = { |
|
"pe_interpolation": config.model.pe_interpolation, |
|
"config": config, |
|
"model_max_length": max_length, |
|
"qk_norm": config.model.qk_norm, |
|
"micro_condition": config.model.micro_condition, |
|
"caption_channels": text_embed_dim, |
|
"y_norm": config.text_encoder.y_norm, |
|
"attn_type": config.model.attn_type, |
|
"ffn_type": config.model.ffn_type, |
|
"mlp_ratio": config.model.mlp_ratio, |
|
"mlp_acts": list(config.model.mlp_acts), |
|
"in_channels": config.vae.vae_latent_dim, |
|
"y_norm_scale_factor": config.text_encoder.y_norm_scale_factor, |
|
"use_pe": config.model.use_pe, |
|
"linear_head_dim": config.model.linear_head_dim, |
|
"pred_sigma": pred_sigma, |
|
"learn_sigma": learn_sigma, |
|
} |
|
model = build_model( |
|
config.model.model, |
|
config.train.grad_checkpointing, |
|
getattr(config.model, "fp32_attention", False), |
|
input_size=latent_size, |
|
**model_kwargs, |
|
).train() |
|
logger.info( |
|
colored( |
|
f"{model.__class__.__name__}:{config.model.model}, " |
|
f"Model Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M", |
|
"green", |
|
attrs=["bold"], |
|
) |
|
) |
|
|
|
if args.load_from is not None: |
|
config.model.load_from = args.load_from |
|
if config.model.load_from is not None and load_from: |
|
_, missing, unexpected, _ = load_checkpoint( |
|
config.model.load_from, |
|
model, |
|
load_ema=config.model.resume_from.get("load_ema", False), |
|
null_embed_path=null_embed_path, |
|
) |
|
logger.warning(f"Missing keys: {missing}") |
|
logger.warning(f"Unexpected keys: {unexpected}") |
|
|
|
|
|
if accelerator.distributed_type == DistributedType.FSDP: |
|
for m in accelerator._models: |
|
m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m) |
|
|
|
|
|
config.data.data_dir = config.data.data_dir if isinstance(config.data.data_dir, list) else [config.data.data_dir] |
|
config.data.data_dir = [ |
|
data if data.startswith(("https://", "http://", "gs://", "/", "~")) else osp.abspath(osp.expanduser(data)) |
|
for data in config.data.data_dir |
|
] |
|
num_replicas = int(os.environ["WORLD_SIZE"]) |
|
rank = int(os.environ["RANK"]) |
|
dataset = build_dataset( |
|
asdict(config.data), |
|
resolution=image_size, |
|
aspect_ratio_type=config.model.aspect_ratio_type, |
|
real_prompt_ratio=config.train.real_prompt_ratio, |
|
max_length=max_length, |
|
config=config, |
|
caption_proportion=config.data.caption_proportion, |
|
sort_dataset=config.data.sort_dataset, |
|
vae_downsample_rate=config.vae.vae_downsample_rate, |
|
) |
|
accelerator.wait_for_everyone() |
|
if config.model.multi_scale: |
|
drop_last = True |
|
uuid = hashlib.sha256("-".join(config.data.data_dir).encode()).hexdigest()[:8] |
|
cache_dir = osp.expanduser(f"~/.cache/_wids_batchsampler_cache") |
|
os.makedirs(cache_dir, exist_ok=True) |
|
base_pattern = ( |
|
f"{cache_dir}/{getpass.getuser()}-{uuid}-sort_dataset{config.data.sort_dataset}" |
|
f"-hq_only{config.data.hq_only}-valid_num{config.data.valid_num}" |
|
f"-aspect_ratio{len(dataset.aspect_ratio)}-droplast{drop_last}" |
|
f"dataset_len{len(dataset)}" |
|
) |
|
cache_file = f"{base_pattern}-num_replicas{num_replicas}-rank{rank}" |
|
for i in config.data.data_dir: |
|
cache_file += f"-{i}" |
|
cache_file += ".json" |
|
|
|
sampler = DistributedRangedSampler(dataset, num_replicas=num_replicas, rank=rank) |
|
batch_sampler = AspectRatioBatchSampler( |
|
sampler=sampler, |
|
dataset=dataset, |
|
batch_size=config.train.train_batch_size, |
|
aspect_ratios=dataset.aspect_ratio, |
|
drop_last=drop_last, |
|
ratio_nums=dataset.ratio_nums, |
|
config=config, |
|
valid_num=config.data.valid_num, |
|
hq_only=config.data.hq_only, |
|
cache_file=cache_file, |
|
caching=args.caching, |
|
) |
|
train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.train.num_workers) |
|
train_dataloader_len = len(train_dataloader) |
|
logger.info(f"rank-{rank} Cached file len: {len(train_dataloader.batch_sampler.cached_idx)}") |
|
else: |
|
sampler = DistributedRangedSampler(dataset, num_replicas=num_replicas, rank=rank) |
|
train_dataloader = build_dataloader( |
|
dataset, |
|
num_workers=config.train.num_workers, |
|
batch_size=config.train.train_batch_size, |
|
shuffle=False, |
|
sampler=sampler, |
|
) |
|
train_dataloader_len = len(train_dataloader) |
|
load_vae_feat = getattr(train_dataloader.dataset, "load_vae_feat", False) |
|
load_text_feat = getattr(train_dataloader.dataset, "load_text_feat", False) |
|
|
|
|
|
lr_scale_ratio = 1 |
|
if getattr(config.train, "auto_lr", None): |
|
lr_scale_ratio = auto_scale_lr( |
|
config.train.train_batch_size * get_world_size() * config.train.gradient_accumulation_steps, |
|
config.train.optimizer, |
|
**config.train.auto_lr, |
|
) |
|
optimizer = build_optimizer(model, config.train.optimizer) |
|
if config.train.lr_schedule_args and config.train.lr_schedule_args.get("num_warmup_steps", None): |
|
config.train.lr_schedule_args["num_warmup_steps"] = ( |
|
config.train.lr_schedule_args["num_warmup_steps"] * num_replicas |
|
) |
|
lr_scheduler = build_lr_scheduler(config.train, optimizer, train_dataloader, lr_scale_ratio) |
|
logger.warning( |
|
f"{colored(f'Basic Setting: ', 'green', attrs=['bold'])}" |
|
f"lr: {config.train.optimizer['lr']:.5f}, bs: {config.train.train_batch_size}, gc: {config.train.grad_checkpointing}, " |
|
f"gc_accum_step: {config.train.gradient_accumulation_steps}, qk norm: {config.model.qk_norm}, " |
|
f"fp32 attn: {config.model.fp32_attention}, attn type: {config.model.attn_type}, ffn type: {config.model.ffn_type}, " |
|
f"text encoder: {config.text_encoder.text_encoder_name}, captions: {config.data.caption_proportion}, precision: {config.model.mixed_precision}" |
|
) |
|
|
|
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) |
|
|
|
if accelerator.is_main_process: |
|
tracker_config = dict(vars(config)) |
|
try: |
|
accelerator.init_trackers(args.tracker_project_name, tracker_config) |
|
except: |
|
accelerator.init_trackers(f"tb_{timestamp}") |
|
|
|
start_epoch = 0 |
|
start_step = 0 |
|
total_steps = train_dataloader_len * config.train.num_epochs |
|
|
|
|
|
if config.model.resume_from is not None and config.model.resume_from["checkpoint"] is not None: |
|
rng_state = None |
|
ckpt_path = osp.join(config.work_dir, "checkpoints") |
|
check_flag = osp.exists(ckpt_path) and len(os.listdir(ckpt_path)) != 0 |
|
if config.model.resume_from["checkpoint"] == "latest": |
|
if check_flag: |
|
checkpoints = os.listdir(ckpt_path) |
|
if "latest.pth" in checkpoints and osp.exists(osp.join(ckpt_path, "latest.pth")): |
|
config.model.resume_from["checkpoint"] = osp.realpath(osp.join(ckpt_path, "latest.pth")) |
|
else: |
|
checkpoints = [i for i in checkpoints if i.startswith("epoch_")] |
|
checkpoints = sorted(checkpoints, key=lambda x: int(x.replace(".pth", "").split("_")[3])) |
|
config.model.resume_from["checkpoint"] = osp.join(ckpt_path, checkpoints[-1]) |
|
else: |
|
config.model.resume_from["checkpoint"] = config.model.load_from |
|
|
|
if config.model.resume_from["checkpoint"] is not None: |
|
_, missing, unexpected, rng_state = load_checkpoint( |
|
**config.model.resume_from, |
|
model=model, |
|
optimizer=optimizer if check_flag else None, |
|
lr_scheduler=lr_scheduler if check_flag else None, |
|
null_embed_path=null_embed_path, |
|
) |
|
|
|
logger.warning(f"Missing keys: {missing}") |
|
logger.warning(f"Unexpected keys: {unexpected}") |
|
|
|
path = osp.basename(config.model.resume_from["checkpoint"]) |
|
try: |
|
start_epoch = int(path.replace(".pth", "").split("_")[1]) - 1 |
|
start_step = int(path.replace(".pth", "").split("_")[3]) |
|
except: |
|
pass |
|
|
|
|
|
if rng_state: |
|
logger.info("resuming randomise") |
|
torch.set_rng_state(rng_state["torch"]) |
|
torch.cuda.set_rng_state_all(rng_state["torch_cuda"]) |
|
np.random.set_state(rng_state["numpy"]) |
|
random.setstate(rng_state["python"]) |
|
generator.set_state(rng_state["generator"]) |
|
|
|
|
|
|
|
|
|
model = accelerator.prepare(model) |
|
optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler) |
|
|
|
|
|
train( |
|
config=config, |
|
args=args, |
|
accelerator=accelerator, |
|
model=model, |
|
optimizer=optimizer, |
|
lr_scheduler=lr_scheduler, |
|
train_dataloader=train_dataloader, |
|
train_diffusion=train_diffusion, |
|
logger=logger, |
|
) |
|
|
|
|
|
# Script entry point: run the training pipeline (main() is defined earlier in
# this file) only when the script is executed directly, not when imported.
if __name__ == "__main__":



    main()
|
|