|
import math |
|
import random |
|
import types |
|
import time |
|
from collections import defaultdict |
|
from contextlib import nullcontext |
|
from functools import cached_property, partial |
|
from contextlib import ExitStack |
|
|
|
from numpy import mask_indices |
|
from unidisc.utils.tensor_utils import get_contiguous_blocks, get_contiguous_blocks_per_sample, get_interleaved_indices |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint |
|
from accelerate.utils import gather, gather_object |
|
from einops import rearrange |
|
from tensordict import TensorDict |
|
from torch import Tensor
|
from tqdm.auto import tqdm |
|
|
|
import model_eval |
|
import model_setup |
|
import model_utils |
|
import utils |
|
from decoupled_utils import (Profiler, barrier, dprint, get_rank, get_world_size, gprint, |
|
is_local_main_process, is_main_process, |
|
is_torch_cuda_available, is_torch_xla_available, |
|
print_memory, rprint, save_memory_profile, |
|
synchronize_device, try_except, use_dist) |
|
from unidisc.tokenizers.image_tokenizers import (decode_latents, get_image_batch, |
|
get_vae, vae_encode_image) |
|
from unidisc.utils.cuda_utils import sync_times |
|
from unidisc.utils.xla_utils import shard_output |
|
from model_utils import (Loss, ddprint, ema_update, empty_device_cache, get_chameleon_txt_indices, get_interleaved_block_mask, log, |
|
replace_nan_dict, update_histogram, update_logs, get_block_mask) |
|
from unidisc.utils.trainer_utils import TrainingState, incremental_dict_update, linear_warmup |
|
|
|
is_xla_available = is_torch_xla_available() |
|
|
|
if is_xla_available: |
|
import torch_xla |
|
from torch_xla.distributed.spmd import XLAShardedTensor |
|
|
|
|
|
|
|
|
|
def maybe_unwrap(t: torch.Tensor) -> torch.Tensor:
    # XLAShardedTensor is only imported when torch_xla is available; guard the isinstance
    # check so this helper does not raise a NameError on CUDA/CPU runs.
    if is_xla_available and isinstance(t, XLAShardedTensor):
        return t.global_tensor
    return t
|
|
|
class Diffusion: |
|
def __init__(self, config, tokenizer, device, disable_init=False): |
|
super().__init__() |
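# Bind functions from model_setup, model_utils, and model_eval onto this instance as methods,
# so the Diffusion class stays a thin coordinator while the heavy logic lives in those modules.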
|
setup_methods = [ |
|
'init', 'to', 'get_params', 'get_vae', 'get_cond_vae', 'configure_optimizers', |
|
'_validate_configuration', 'register_signal_handler', 'on_train_start', |
|
'optimizer_step', 'init_dataloader', 'set_accelerator', 'set_callbacks', |
|
'on_train_step_end', 'init_optimizer_lr_scheduler', 'after_backward', 'checkpoint', |
|
'print_hashes', 'shortcut_return', 'reset_validation_metrics', 'unwrap_model' |
|
] |
|
for method_name in setup_methods: |
|
setattr(self, method_name, types.MethodType(getattr(model_setup, method_name), self)) |
|
|
|
utils_methods = [ |
|
'get_coord_plot', '_score_entropy', 'sample_subs_guidance', |
|
'restore_model_and_semi_ar_sample', '_reconstruction_loss', |
|
'restore_model_and_sample', 'get_score', '_staggered_score', |
|
'_analytic_update', '_denoiser_update', '_transp_transition', |
|
'eval_retokenize', 'compute_generative_perplexity', '_d3pm_loss', |
|
'_d3pm_parameterization', '_sedd_parameterization', |
|
'get_base_shapes_for_mup', 'update_histogram', '_maybe_sub_sample', |
|
'viz_images_from_dataloader', 'compute_cider' |
|
] |
|
for method_name in utils_methods: |
|
setattr(self, method_name, types.MethodType(getattr(model_utils, method_name), self)) |
|
|
|
eval_methods = [ |
|
'get_every_n_evals', 'on_validation_epoch_start', 'sample', |
|
'predict_step', 'validation_step', 'on_validation_epoch_end', |
|
'on_validation_epoch_cleanup', '_sample_prior', '_ddpm_forward', |
|
'_ddpm_update', '_ddpm_caching_update', '_sample', '_ar_sampler', |
|
'decode_batch', 'sample_transfusion', 'sample_continuous_image', |
|
'decode_sampling', '_ddpm_update_finetune_controlled_tweedie', |
|
'sample_masking', 'log_flops', "visualize_samples", "_maskgit_update", |
|
"_first_hitting_update", "update_inline_fid", "compute_inline_fid", |
|
"update_clean_fid", "compute_clean_fid_eval", "sample_for_fid", |
|
"compute_clip_score", "mauve_store_references", "zero_shot_eval_step", |
|
"zero_shot_eval_epoch_end", "get_cfg_weight", "cleanup_fid_output", |
|
"calculate_chameleon_perplexity", "get_anole_data", |
|
"update_img_to_txt_mauve_clip", "compute_mauve_entropy", |
|
"get_top_k", "compute_entropy", "get_mauve_score", "get_valid_seq", "gather_tokens", |
|
"count_valid_tokens", "compute_val_metrics_standalone", "_maskgit_nucleus_update", |
|
"get_img_text_saturation_batch", "handle_interleaved_decode", "get_interleaved_image", |
|
"auto_enhance", "get_clip_score", "get_dfn_score", "get_hpsv2_score", "get_model_likelihood_score", |
|
"get_laion_aesthetic_score", "get_rewards", "get_chameleon_score", "clear_reward_models", |
|
"get_text_likelihood_score", "get_text_reward_model_score", "save_image_text_pair" |
|
] |
|
for method_name in eval_methods: |
|
setattr(self, method_name, types.MethodType(getattr(model_eval, method_name), self)) |
|
|
|
if not disable_init:
    model_setup.init(self, config, tokenizer, device)
|
|
|
@cached_property |
|
def xla_mesh(self): |
|
import torch_xla.distributed.spmd as xs |
|
return xs.get_global_mesh() |
|
|
|
def on_train_resume(self): |
|
if not is_torch_xla_available(): |
|
empty_device_cache() |
|
|
|
if self.ema is not None and not self.config.trainer.use_custom_ema: |
|
self.ema.restore(self.get_params(), raise_error_if_already_restored=False) |
|
|
|
self.backbone.train() |
|
|
|
def zero_shot_update_batch(self, batch): |
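# Build zero-shot evaluation batches: encode images with the VAE, offset the resulting tokens
# into the image vocabulary range, and assemble input_ids / attention_mask / modality for the
# supported eval datasets (flickr30k and winoground).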
|
dataset = self.config.data.train |
|
if dataset is None: |
|
return batch |
|
|
|
def get_attr(attr_name): |
|
return getattr(self.config.model, attr_name, None) |
|
|
|
if dataset == "nlphuji/flickr30k": |
|
|
|
|
|
batch['gt_input_ids'] = batch['input_ids'] |
|
image_input_ids = get_image_batch(self.config, self.get_vae(), batch, self.device) |
|
image_input_ids += self.text_vocab_size |
|
batch["input_ids"] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.int64), image_input_ids], dim=-1).to(self.device) |
|
batch['attention_mask'] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.bool), torch.ones_like(image_input_ids, dtype=torch.bool)], dim=-1).to(self.device) |
|
batch["modality"] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.int64), torch.ones_like(image_input_ids, dtype=torch.int64)], dim=-1).to(self.device) |
|
elif dataset == "facebook/winoground": |
|
|
|
caption_0_input_ids = batch['caption_0_input_ids'] |
|
caption_1_input_ids = batch['caption_1_input_ids'] |
|
image_0 = batch['img_0'] |
|
image_1 = batch['img_1'] |
|
|
|
image_0_input_ids = vae_encode_image(self.config, self.get_vae(), image_0, self.device, get_attr("vae_type")) + self.text_vocab_size |
|
image_1_input_ids = vae_encode_image(self.config, self.get_vae(), image_1, self.device, get_attr("vae_type")) + self.text_vocab_size |
|
|
|
batch['input_ids_0_0'] = torch.cat([caption_0_input_ids, image_0_input_ids], dim=-1).to(self.device) |
|
batch['input_ids_0_1'] = torch.cat([caption_0_input_ids, image_1_input_ids], dim=-1).to(self.device) |
|
batch['input_ids_1_0'] = torch.cat([caption_1_input_ids, image_0_input_ids], dim=-1).to(self.device) |
|
batch['input_ids_1_1'] = torch.cat([caption_1_input_ids, image_1_input_ids], dim=-1).to(self.device) |
|
batch['attention_mask'] = torch.cat([torch.zeros_like(caption_0_input_ids, dtype=torch.bool), torch.ones_like(image_0_input_ids, dtype=torch.bool)], dim=-1).to(self.device) |
|
batch['modality'] = torch.cat([torch.zeros_like(caption_0_input_ids, dtype=torch.int64), torch.ones_like(image_0_input_ids, dtype=torch.int64)], dim=-1).to(self.device) |
|
|
|
batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool) |
|
return batch |
|
|
|
def update_batch(self, batch): |
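# Normalize a raw dataloader batch into the unified format the model expects: tokenize images,
# shift image tokens past the text vocabulary, and populate input_ids, attention_mask,
# modality, modality_mask, and (for interleaved data) sample_ids.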
|
if getattr(self.config.eval, 'big_seq_len_eval', False): |
|
|
|
N = self.config.model.length |
|
new_batch = dict() |
|
new_batch['input_ids'] = torch.zeros(batch['input_ids'].shape[0], N, device=self.device, dtype=batch['input_ids'].dtype) |
|
new_batch['attention_mask'] = torch.ones(batch['attention_mask'].shape[0], N, device=self.device, dtype=batch['attention_mask'].dtype) |
|
new_batch['modality'] = torch.zeros(batch['modality'].shape[0], N, device=self.device, dtype=batch['modality'].dtype) |
|
new_batch['modality'][:, N//2:] = 1 |
|
new_batch['modality_mask'] = F.one_hot(new_batch['modality'], num_classes=2).to(torch.bool) |
|
batch = new_batch |
|
return batch |
|
|
|
continuous_mode = self.config.trainer.image_mode == "continuous" |
|
if batch is None: |
|
gprint(f"Warning! Batch is None") |
|
return batch |
|
|
|
if isinstance(batch, TensorDict): |
|
batch.batch_size = (batch.batch_size[0],) |
|
|
|
if self.image_model or getattr(self.config.data, "force_image_dataset", False): |
|
text_input_ids = None |
|
if isinstance(batch, TensorDict) and (self.is_compiled or getattr(self.config.trainer, "force_convert_to_dict", False)): |
|
batch = dict(batch.items()) |
|
|
|
if "txt_input_ids" in batch or "img_input_ids" in batch: |
|
index_keys = ["img_input_ids", "txt_input_ids", "sample_ids"] |
|
for key in index_keys: |
|
if key in batch: |
|
if isinstance(batch[key], list): |
|
batch[key] = torch.stack(batch[key], dim=0) |
|
batch[key] = batch[key].to(torch.int64) |
|
|
|
index_keys = ["img_label"] |
|
for key in index_keys: |
|
if key in batch: |
|
batch[key] = batch[key].squeeze(-1) |
|
|
|
img_input_ids = batch.pop("img_input_ids") |
|
batch["input_ids"] = img_input_ids |
|
batch["attention_mask"] = torch.ones_like(img_input_ids).to(torch.bool) |
|
if "txt_input_ids" in batch: |
|
batch["input_ids"] = torch.cat([batch["txt_input_ids"], batch["input_ids"] + self.text_vocab_size], dim=-1) |
|
batch["attention_mask"] = torch.cat([batch["txt_attention_mask"], batch["attention_mask"]], dim=-1) |
|
|
|
batch["input_ids"] = batch["input_ids"].to(torch.int64) |
|
|
|
if "modality" not in batch: |
|
if getattr(self.config.trainer, "ignore_text_in_unified", False): |
|
modality = torch.ones_like(batch["input_ids"], dtype=torch.int64) |
|
else: |
|
assert self.config.model.txt_length > 0 and self.config.model.img_length > 0 |
|
modality = torch.zeros_like(batch["input_ids"], dtype=torch.int64) |
|
modality[:, -img_input_ids.shape[-1]:] = 1 |
|
batch["modality"] = modality |
|
|
|
elif (self.config.trainer.multimodal_batches or continuous_mode) and \ |
|
not getattr(self.config.trainer, "use_legacy_update_batch_fn", False): |
|
|
|
if "img" in batch: |
|
is_image_batch = (batch["modality"] == 1).all(dim=-1) |
|
image_input_ids = get_image_batch(self.config, self.get_vae(), batch, self.device) |
|
assert ((batch["modality"].sum(dim=-1) == 0) | (batch["modality"].sum(dim=-1) >= image_input_ids.shape[1])).all() |
|
|
|
if getattr(self.config.trainer, "add_label", False): |
|
assert (batch["modality"] == 1).all() |
|
batch["input_ids"][:, 1:] = torch.where(is_image_batch[:, None], image_input_ids, batch["input_ids"][:, 1:]) |
|
elif image_input_ids.ndim == 3: |
|
batch["img_emb"] = torch.where((batch["modality"] == 1)[:, :, None], image_input_ids, torch.nan) |
|
elif (batch["input_ids"][batch["modality"] == 1] == -1).all(): |
|
batch["input_ids"].masked_scatter_(batch["modality"] == 1, image_input_ids) |
|
else: |
|
batch["input_ids"] = torch.where(is_image_batch[:, None], image_input_ids, batch["input_ids"]) |
|
|
|
if getattr(self.config.trainer, "force_shift_raw_image_batches", False): |
|
assert not getattr(self.config.trainer, "force_shift_image_batches", False) |
|
batch["input_ids"] = torch.where(batch["modality"] == 1, batch["input_ids"] + self.text_vocab_size, batch["input_ids"]) |
|
else: |
|
if getattr(self.config.trainer, "add_label", False): |
|
shift_index = self.vocab_size - self.config.model.add_labels |
|
batch["input_ids"] = torch.cat([batch["label"] + shift_index, batch["input_ids"]], dim=-1) |
|
batch["attention_mask"] = torch.cat([torch.zeros_like(batch["label"], dtype=torch.bool), batch["attention_mask"]], dim=-1) |
|
batch["modality"] = torch.cat([torch.ones_like(batch["label"], dtype=torch.int64), batch["modality"]], dim=-1) |
|
assert (batch["modality"] == 1).all() |
|
|
|
batch["input_ids"] = batch["input_ids"].to(torch.int64) |
|
if "sample_ids" in batch: |
|
batch["sample_ids"] = batch["sample_ids"].to(torch.int64) |
|
|
|
if getattr(self.config.trainer, "force_shift_image_batches", False): |
|
batch["input_ids"] = torch.where(batch["modality"] == 1, batch["input_ids"] + self.text_vocab_size, batch["input_ids"]) |
|
else: |
|
if continuous_mode: |
|
assert False, "continuous image mode is not supported in this legacy image-batch path"
|
else: |
|
if "input_ids" in batch and not self.config.trainer.ignore_text_in_unified: |
|
assert self.config.model.unified_model |
|
assert "attention_mask" in batch |
|
text_input_ids = batch["input_ids"] |
|
|
|
image_ids = get_image_batch(self.config, self.get_vae(), batch, self.device) |
|
image_attention_mask = torch.ones_like(image_ids).to(torch.bool) |
|
|
|
if "cond_img" in batch: |
|
cond_image_ids = get_image_batch(self.config, self.get_cond_vae(), batch, self.device, use_cond=True) |
|
batch["cond_input_ids"] = cond_image_ids |
|
|
|
if text_input_ids is not None: |
|
assert batch["input_ids"].shape[1] == self.config.model.txt_length |
|
assert image_ids.shape[1] == self.config.model.img_length |
|
image_ids = image_ids + self.text_vocab_size |
|
|
|
batch["input_ids"] = torch.cat([batch["input_ids"].to(self.device), image_ids], dim=-1) |
|
batch["attention_mask"] = torch.cat([batch["attention_mask"].to(self.device), image_attention_mask], dim=-1).to(torch.bool) |
|
assert batch["input_ids"].shape[1] == batch["attention_mask"].shape[1] == self.config.model.length |
|
batch["modality"] = torch.zeros_like(batch["input_ids"], dtype=torch.int64) |
|
batch["modality"][:, -image_ids.shape[-1]:] = 1 |
|
else: |
|
assert self.unified_model is False |
|
batch["input_ids"] = image_ids |
|
batch["attention_mask"] = image_attention_mask |
|
batch["modality"] = torch.ones_like(batch["input_ids"], dtype=torch.int64) |
|
|
|
if "txt_x0_unmask" in batch and "img_x0_unmask" in batch: |
|
assert not continuous_mode |
|
batch["gt_img_input_ids"] = image_ids |
|
batch["x0_unmask"] = torch.cat([batch["txt_x0_unmask"], batch["img_x0_unmask"]], dim=-1) |
|
batch["input_ids"][~batch["x0_unmask"]] = self.mask_index |
|
|
|
if (batch["input_ids"].shape[1] != self.config.model.length) and not self.config.trainer.ar_inpainting: |
|
gprint(f"Warning! Input ids shape: {batch['input_ids'].shape}, model length: {self.config.model.length}") |
|
batch["input_ids"] = batch["input_ids"][:, : self.config.model.length] |
|
assert False, f"input_ids are not the correct length; got shape {batch['input_ids'].shape} but model length is {self.config.model.length}"
|
|
|
if getattr(self.config.model, "img_cond", False): |
|
assert "cond_input_ids" in batch |
|
assert not continuous_mode |
|
|
|
if "modality" in batch: |
|
batch["modality"] = batch["modality"].to(torch.int64) |
|
if self.config.trainer.multimodal_batches and batch["modality"].ndim == 2 and batch["modality"].shape[-1] == 1: |
|
batch["modality"] = batch["modality"].repeat(1, self.config.model.length) |
|
else: |
|
if self.image_model and not self.config.trainer.multimodal_batches: |
|
assert self.config.model.txt_length > 0 and self.config.model.img_length > 0 |
|
modality = torch.zeros_like(batch["input_ids"], dtype=torch.int64) |
|
modality[:, self.static_img_sl] = 1 |
|
batch["modality"] = modality |
|
elif self.config.data.txt_only: |
|
batch["modality"] = torch.zeros_like(batch["input_ids"], dtype=torch.int64) |
|
|
|
if "modality" in batch: |
|
batch["modality"][batch["modality"] == -1] = 0 |
|
assert batch["modality"].min() == 0 and batch["modality"].max() == 1 |
|
batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool) |
|
batch["batch_contains_img"] = (batch["modality"] == 1).any(dim=-1) |
|
batch['txt_sl'] = self.txt_sl(batch) |
|
batch['img_sl'] = self.img_sl(batch) |
|
|
|
if getattr(self.config.trainer, "force_remove_img_tokens", False): |
|
assert not continuous_mode |
|
batch["input_ids"] = batch["input_ids"][batch['txt_sl']] |
|
batch["attention_mask"] = batch["attention_mask"][batch['txt_sl']] |
|
|
|
if getattr(self.config.trainer, "add_label", False): |
|
assert getattr(self.config.model, "add_labels", False) |
|
assert "label" in batch |
|
batch["label"] = batch["label"].to(torch.int64) |
|
assert 0 <= batch["label"].min() and batch["label"].max() < self.config.model.add_labels |
|
shift_index = self.vocab_size - self.config.model.add_labels |
|
|
|
assert batch["input_ids"].shape[-1] == self.config.model.length |
|
if batch["label"].ndim == 1: |
|
batch["input_ids"][:, [0]] = (batch["label"] + shift_index).unsqueeze(-1) |
|
else: |
|
batch["input_ids"][:, [0]] = batch["label"] + shift_index |
|
|
|
batch["attention_mask"][:, 0] = False |
|
|
|
if isinstance(batch, dict): |
|
for key in batch.keys(): |
|
if isinstance(batch[key], torch.Tensor): |
|
batch[key] = batch[key].to(self.device) |
|
elif isinstance(batch, TensorDict): |
|
assert self.config.backbone != "gemma" |
|
batch = batch.to(self.device) |
|
|
|
if getattr(self.config.trainer, "force_full_attention_mask", False): |
|
batch["attention_mask"] = torch.ones_like(batch["attention_mask"], dtype=torch.bool) |
|
|
|
batch["attention_mask"] = batch["attention_mask"].to(torch.bool) |
|
|
|
if self.config.data.require_sample_ids: |
|
assert "sample_ids" in batch |
|
batch["sample_ids"][~(batch["attention_mask"].bool())] = -1 |
|
batch["attention_mask"][batch["sample_ids"] == -1] = False |
|
|
|
|
|
|
|
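# For AR training, randomly swap the text and image halves of a subset of samples so the model
# is trained on both text-then-image and image-then-text orderings.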
if (self.training or getattr(self.config.trainer, "force_flip_ar_val", False)) and self.config.parameterization == "ar" and getattr(self.config.trainer, "rand_flip_ar_prob", None) is not None: |
|
assert (batch["modality"][:, :self.config.model.txt_length] == 0).all() and (batch["modality"][:, self.config.model.txt_length:] == 1).all(), "Modality does not match img_before_txt configuration" |
|
batch_flip_mask = torch.rand(batch["modality"].shape[0], device=self.device) < self.config.trainer.rand_flip_ar_prob |
|
img_slice = slice(-self.config.model.img_length, None) |
|
txt_slice = slice(None, self.config.model.txt_length) |
|
|
|
for key in ["modality", "attention_mask", "input_ids"]: |
|
batch[key][batch_flip_mask] = torch.cat([batch[key][batch_flip_mask][:, img_slice], batch[key][batch_flip_mask][:, txt_slice]], dim=1) |
|
|
|
if "modality_mask" in batch: |
|
batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool) |
|
|
|
batch['txt_sl'] = None |
|
batch['img_sl'] = None |
|
batch["batch_flip_mask"] = batch_flip_mask |
|
|
|
if self.config.trainer.interleaved and "sample_ids" not in batch: |
|
batch["sample_ids"] = torch.zeros_like(batch["modality"], dtype=torch.int64) |
|
|
|
if self.config.trainer.interleaved: |
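# Record contiguous same-modality blocks (batch index, start, end) so downstream code can
# validate image block sizes and index interleaved text/image spans.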
|
batch_indices, start_positions, end_positions = get_contiguous_blocks(batch["modality"]) |
|
interleaved_metadata = TensorDict({ |
|
"batch_indices": batch_indices, |
|
"start_positions": start_positions, |
|
"end_positions": end_positions |
|
}, batch_size=[]) |
|
allowed_image_sizes = (64, 256, 1024, 2304, 4096) |
|
block_sizes = (end_positions - start_positions).to(torch.int32) |
|
is_txt_block = batch["modality"][batch_indices, start_positions] == 0 |
|
is_valid_img_size = torch.isin(block_sizes, torch.tensor(allowed_image_sizes, dtype=torch.int32, device=self.device)) |
|
|
|
if not ((is_txt_block | is_valid_img_size).all()): |
|
gprint(f"WARNING: Found non-text block of size {block_sizes[~(is_txt_block | is_valid_img_size)]} in interleaved batch") |
|
|
|
if isinstance(batch, TensorDict): |
|
batch.batch_size = [] |
|
batch["interleaved_metadata"] = interleaved_metadata |
|
|
|
return batch |
|
|
|
def get_cond_dict(self, batch): |
|
ret_dict = dict() |
|
if "cond_input_ids" in batch: |
|
ret_dict["x_cond"] = batch["cond_input_ids"] |
|
|
|
if "img_label" in batch: |
|
ret_dict["label"] = batch["img_label"] |
|
|
|
if self.config.model.use_attention_mask: |
|
ret_dict["attention_mask"] = batch["attention_mask"] |
|
|
|
if self.config.trainer.multimodal_batches: |
|
ret_dict["modality"] = batch["modality"] |
|
|
|
if self.config.trainer.image_mode == "continuous": |
|
ret_dict["continuous_mode"] = True |
|
ret_dict["modality"] = batch["modality"] |
|
|
|
if self.parameterization == "ar" and "modality" in batch: |
|
ret_dict["modality"] = batch["modality"] |
|
|
|
return ret_dict |
|
|
|
def training_step(self, batch, batch_idx): |
|
batch = self.update_batch(batch) |
|
return self.compute_loss(batch, prefix="train", batch_idx=batch_idx) |
|
|
|
def q_xt(self, x, move_chance, allow_move_mask=None, return_ignore_batch_mask_for_metrics=False, mask_image_square=False, mask_text_region=False, batch=None): |
|
"""Computes the noisy sample xt. |
|
|
|
Args: |
|
x: int torch.Tensor with shape (batch_size, |
|
diffusion_model_input_length), input. |
|
move_chance: float torch.Tensor with shape (batch_size, 1). |
|
""" |
|
if self.config.backbone == "maskdit" and getattr(self.config.trainer, "force_single_timestep_per_batch", False): |
|
num_to_mask = int(x.shape[1] * move_chance[0].item()) |
|
batch_size, seq_len = x.shape |
|
random_indices = torch.rand(batch_size, seq_len, device=x.device).argsort(dim=1)[:, :num_to_mask] |
|
xt = x.scatter(1, random_indices, self.mask_index) |
|
return xt |
|
|
|
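# Forward process: each token is independently selected for corruption with probability
# move_chance (replaced by the mask token in the default absorbing mode).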
move_indices = torch.rand(*x.shape, device=x.device) < move_chance |
|
|
|
if mask_image_square: |
|
latent_dim = int(math.sqrt(self.config.model.img_length)) |
|
img_move_indices = move_indices[:, self.static_img_sl].clone().reshape(move_indices.shape[0], latent_dim, latent_dim) |
|
max_d = int(math.sqrt(self.config.model.img_length)) |
|
for b in range(move_indices.shape[0]): |
|
if move_chance[b] == 1: |
|
continue |
|
h, w = img_move_indices[b].shape |
|
d = random.randint(max_d // 2, max_d - 2) |
|
i = random.randint(0, h - d) |
|
j = random.randint(0, w - d) |
|
|
|
mask = torch.zeros_like(img_move_indices[b], dtype=torch.bool) |
|
mask[i:i+d, j:j+d] = True |
|
move_indices[b, self.static_img_sl] = mask.reshape(-1) |
|
|
|
if mask_text_region: |
|
for b in range(x.shape[0]): |
|
if move_chance[b] == 1: |
|
continue |
|
should_mask = torch.zeros_like(move_indices[b, self.static_txt_sl], dtype=torch.bool) |
|
max_valid = (x[b] == self.tokenizer.eos_token_id).nonzero()[0, 0] if self.tokenizer.eos_token_id in x[b] else x.shape[1] |
|
d = random.randint(max_valid//3, max_valid-1) |
|
start = random.randint(0, max_valid - d) |
|
should_mask[start:start+d] = True |
|
move_indices[b, self.static_txt_sl] = should_mask |
|
|
|
ignore_batch_mask_for_metrics = None |
|
should_mask_txt, should_mask_img = None, None |
|
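# Optionally mask an entire modality (all text or all image tokens) for a random subset of
# samples; those samples are excluded from standard metrics via ignore_batch_mask_for_metrics.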
if (mask_prob := getattr(self.config.trainer, "mask_entire_modality", None)) is not None \ |
|
and (mask_image_square is False and mask_text_region is False) and self.backbone.training: |
|
|
|
assert batch is not None |
|
batch_size, seq_len = x.shape |
|
if getattr(self.config.trainer, "mask_txt_only", False): |
|
should_mask_txt = torch.rand(batch_size, 1, device=x.device) < mask_prob |
|
should_mask_img = torch.zeros_like(should_mask_txt, device=x.device) |
|
else: |
|
should_mask_txt = torch.rand(batch_size, 1, device=x.device) < mask_prob/2 |
|
should_mask_img = torch.rand(batch_size, 1, device=x.device) < mask_prob/2 |
|
|
|
if self.config.trainer.multimodal_batches: |
|
if self.config.trainer.interleaved: |
|
batch_indices, start_positions, end_positions = get_contiguous_blocks_per_sample(batch["modality"], batch["sample_ids"]) |
|
|
|
block_size = end_positions - start_positions |
|
size_mask = block_size > 4 |
|
batch_indices, start_positions, end_positions = batch_indices[size_mask], start_positions[size_mask], end_positions[size_mask] |
|
|
|
|
|
block_counts = torch.zeros_like(batch_indices) |
|
max_num_sample_ids = torch.zeros_like(batch_indices) |
|
|
|
|
|
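# For each contiguous block, count how many earlier blocks belong to the same sample and how
# many blocks that sample has in total; later blocks then receive a proportionally higher
# chance of being fully masked.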
for i in range(len(batch_indices)): |
|
curr_sample_id = batch["sample_ids"][batch_indices[i], start_positions[i]] |
|
|
|
|
|
prev_blocks_mask = (batch_indices[:i] == batch_indices[i]) & \ |
|
(batch["sample_ids"][batch_indices[:i], start_positions[:i]] == curr_sample_id) |
|
|
|
total_in_sample = ((batch_indices == batch_indices[i]) & (batch["sample_ids"][batch_indices, start_positions] == curr_sample_id)).sum() |
|
|
|
block_counts[i] = prev_blocks_mask.sum() |
|
max_num_sample_ids[i] = total_in_sample |
|
|
|
block_prob = (block_counts + 1) / max_num_sample_ids |
|
positions = torch.arange(move_indices.shape[-1], device=move_indices.device).unsqueeze(0) |
|
mask = (positions >= start_positions.unsqueeze(1)) & (positions < end_positions.unsqueeze(1)) |
|
mask = mask & (torch.rand(batch_indices.shape[0], 1, device=x.device) < (mask_prob * block_prob * 2)[..., None]) |
|
expanded_batch_indices = batch_indices.unsqueeze(1).expand(-1, move_indices.shape[1]) |
|
|
|
|
|
accum = torch.zeros_like(move_indices, dtype=torch.int32) |
|
accum.scatter_add_(0, expanded_batch_indices, mask.int()) |
|
accum = accum.to(torch.bool) |
|
|
|
move_indices = move_indices | accum |
|
|
|
|
|
ignore_batch_mask_for_metrics = torch.zeros((move_indices.shape[0],), device=x.device, dtype=torch.bool) |
|
ignore_batch_mask_for_metrics.scatter_add_(0, batch_indices, mask.any(dim=-1)) |
|
else: |
|
|
|
|
|
both_mask = should_mask_txt & should_mask_img |
|
should_mask_txt = torch.where(both_mask, False, should_mask_txt) |
|
should_mask_img = torch.where(both_mask, False, should_mask_img) |
|
move_indices = torch.where(should_mask_txt, batch["modality_mask"][..., 0], move_indices) |
|
move_indices = torch.where(should_mask_img, batch["modality_mask"][..., 1], move_indices) |
|
ignore_batch_mask_for_metrics = should_mask_img | should_mask_txt |
|
else: |
|
both_mask = should_mask_txt & should_mask_img |
|
should_mask_txt[both_mask] = False |
|
should_mask_img[both_mask] = False |
|
should_mask_img[batch["txt_sl"].all(dim=-1)] = False |
|
move_indices[:, self.static_txt_sl] = torch.where(should_mask_txt, True, move_indices[:, self.static_txt_sl]) |
|
move_indices[:, self.static_img_sl] = torch.where(should_mask_img, True, move_indices[:, self.static_img_sl]) |
|
ignore_batch_mask_for_metrics = should_mask_img | should_mask_txt |
|
|
|
joint_ar_nar_mask = None |
|
if self.config.trainer.joint_ar_nar_prob is not None and self.training: |
|
batch_size = x.shape[0] |
|
current_prob = linear_warmup( |
|
current_step=self.global_step, |
|
warmup_steps=self.config.trainer.joint_ar_nar_prob_warmup_steps, |
|
final_value=self.config.trainer.joint_ar_nar_prob, |
|
initial_value=1.0 |
|
) |
|
joint_ar_nar_mask = torch.rand(batch_size, device=x.device) < current_prob |
|
move_indices = torch.where(joint_ar_nar_mask[:, None], False, move_indices) |
|
|
|
if self.config.trainer.add_label: |
|
move_indices[:, 0] = False |
|
|
|
if self.config.trainer.first_token_dropout is not None and self.training: |
|
_initial_mask = torch.rand(x.shape[0], device=x.device) < self.config.trainer.first_token_dropout |
|
move_indices[:, 0] = torch.where(_initial_mask, True, move_indices[:, 0]) |
|
if ignore_batch_mask_for_metrics is None: |
|
ignore_batch_mask_for_metrics = _initial_mask |
|
else: |
|
ignore_batch_mask_for_metrics = ignore_batch_mask_for_metrics | _initial_mask |
|
|
|
if allow_move_mask is not None: |
|
move_indices = move_indices & allow_move_mask |
|
|
|
if getattr(self.config.trainer, "discrete_diffusion_mode", "absorbing") == "uniform": |
|
if getattr(self.config.model, "force_argmax_valid_indices", False): |
|
assert self.mask_index == self.text_vocab_size - 1 |
|
text_random_tokens = torch.randint(0, self.text_vocab_size - 1, size=x.shape, device=x.device) |
|
img_random_tokens = torch.randint(self.text_vocab_size, self.vocab_size, size=x.shape, device=x.device) |
|
random_tokens = torch.where(batch["modality_mask"][..., 0], text_random_tokens, img_random_tokens) |
|
assert not torch.any(random_tokens == self.mask_index) |
|
else: |
|
random_tokens = torch.randint(0, self.vocab_size, size=x.shape, device=x.device)
|
random_tokens = torch.where(random_tokens == self.mask_index, random_tokens + 1, random_tokens) |
|
xt = torch.where(move_indices, random_tokens, x) |
|
else: |
|
xt = torch.where(move_indices, self.mask_index, x) |
|
|
|
if self.parameterization == "ar": |
|
xt = x.clone() |
|
|
|
if return_ignore_batch_mask_for_metrics: |
|
return xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, should_mask_txt, should_mask_img, move_indices |
|
else: |
|
return xt |
|
|
|
def _sample_t(self, n, device): |
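# Sample per-example diffusion timesteps, with optional antithetic (stratified) sampling and an
# optional warmup that limits the maximum timestep early in joint AR/NAR training.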
|
if self.config.backbone == "maskdit" and getattr(self.config.trainer, "force_single_timestep_per_batch", False): |
|
_eps_t = torch.rand(1, device=device).repeat(n) |
|
else: |
|
_eps_t = torch.rand(n, device=device) |
|
if self.config.trainer.joint_ar_nar_timestep_warmup_steps is not None: |
|
max_t = linear_warmup( |
|
current_step=self.global_step, |
|
warmup_steps=self.config.trainer.joint_ar_nar_timestep_warmup_steps, |
|
final_value=1, |
|
initial_value=0, |
|
start_step=0 |
|
) |
|
_eps_t = _eps_t * max_t |
|
if max_t == 1: |
|
offset = torch.arange(n, device=device) / n |
|
_eps_t = (_eps_t / n + offset) % 1 |
|
|
|
elif self.antithetic_sampling: |
|
offset = torch.arange(n, device=device) / n |
|
_eps_t = (_eps_t / n + offset) % 1 |
|
|
|
if getattr(self.config.trainer, "force_timestep", None) is not None: |
|
_eps_t[:] = self.config.trainer.force_timestep |
|
elif getattr(self.config.eval, "ar_inpainting_force_val", None) is not None: |
|
_eps_t[:] = self.config.eval.ar_inpainting_force_val |
|
|
|
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps |
|
if self.importance_sampling: |
|
return self.noise.importance_sampling_transformation(t) |
|
return t.to(torch.float32) |
|
|
|
def _subs_parameterization(self, logits, xt, batch=None, modality=None, **kwargs): |
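# SUBS parameterization: remove probability mass from the mask token, renormalize the logits
# into log-probabilities, and force already-unmasked positions to copy their current token.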
|
|
|
if not self.allow_slicing: |
|
logits = logits.clone() |
|
|
|
logits[..., self.mask_index] += self.neg_infinity |
|
if getattr(self.config.model, "force_argmax_valid_indices", False): |
|
if self.config.trainer.multimodal_batches: |
|
_txt_sl = batch["txt_sl"] if modality is None else modality == 0 |
|
_img_sl = batch["img_sl"] if modality is None else modality == 1 |
|
logits[..., self.text_vocab_size:] = torch.where(_txt_sl[..., None], self.neg_infinity, logits[..., self.text_vocab_size:]) |
|
logits[..., :self.text_vocab_size] = torch.where(_img_sl[..., None], self.neg_infinity, logits[..., :self.text_vocab_size]) |
|
else: |
|
logits[..., self.static_txt_sl, self.text_vocab_size:] = self.neg_infinity |
|
logits[..., self.static_img_sl, :self.text_vocab_size] = self.neg_infinity |
|
|
|
|
|
|
|
logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) |
|
|
|
if self.parameterization != "ar" and xt is not None: |
|
|
|
|
|
|
|
|
|
unmasked_indices = xt != self.mask_index |
|
if not self.allow_slicing: |
|
logits = torch.where(unmasked_indices.unsqueeze(-1), torch.full_like(logits, self.neg_infinity), logits) |
|
logits = torch.where( |
|
unmasked_indices.unsqueeze(-1) & (torch.arange(logits.size(-1)).to(logits.device) == xt.unsqueeze(-1)), |
|
torch.zeros_like(logits), |
|
logits |
|
) |
|
else: |
|
logits[unmasked_indices] = self.neg_infinity |
|
logits[unmasked_indices, xt[unmasked_indices]] = 0 |
|
|
|
return logits |
|
|
|
def _process_sigma(self, sigma): |
|
if sigma is None: |
|
assert (self.parameterization == "ar" or self.config.trainer.ar_llm_loss) or self.config.trainer.allow_null_sigma |
|
return sigma |
|
|
|
if sigma.ndim > 1 and not self.config.trainer.image_mode == "continuous": |
|
sigma = sigma.squeeze(-1) |
|
assert sigma.ndim == 1, sigma.shape |
|
|
|
if not self.time_conditioning and getattr(self.config.model, "force_time_conditioning", False): |
|
sigma = torch.zeros_like(sigma) |
|
|
|
return sigma |
|
|
|
def forward( |
|
self, |
|
x, |
|
sigma, |
|
batch=None, |
|
forward_attention_mask=None, |
|
return_additional_loss=False, |
|
x_img_emb=None, |
|
disable_ar_shift=False, |
|
continuous_mode=False, |
|
joint_ar_nar_mask=None, |
|
return_logits=False, |
|
block_mask=None, |
|
update_cache_slice=None, |
|
**kwargs, |
|
): |
|
"""Returns log score.""" |
|
sigma = self._process_sigma(sigma) |
|
if self.config.trainer.image_mode == "continuous": assert "modality" in kwargs |
|
should_autocast = (((self.config.trainer.disable_forward_autocast_during_eval and self.backbone.training) is False) and (self.dtype != torch.float32)) |
|
with ExitStack() as stack: |
|
if should_autocast: |
|
stack.enter_context(torch.autocast(device_type=self.device.type, dtype=self.dtype)) |
|
|
|
orig_modality = None |
|
if self.config.backbone == "elm": |
|
if getattr(self.config.trainer, "print_llm_ppl", False): |
|
_labels = x.clone() |
|
_labels[~forward_attention_mask] = -100 |
|
kwargs['labels'] = _labels |
|
|
|
if "modality" in kwargs: |
|
if self.config.mode == "eval": orig_modality = kwargs.pop("modality") |
|
else: kwargs.pop("modality") |
|
|
|
if "modality_mask" in kwargs: kwargs.pop("modality_mask") |
|
if "x0" in kwargs: kwargs.pop("x0") |
|
if "start_pos" in kwargs: kwargs.pop("start_pos") |
|
if "sample_ids" in kwargs: kwargs.pop("sample_ids") |
|
|
|
output = self.backbone(input_ids=x, **kwargs) |
|
|
|
if self.config.mode == "eval": kwargs["modality"] = orig_modality |
|
|
|
if isinstance(output, Tensor): |
|
logits = output |
|
else: |
|
logits = output.logits |
|
|
|
if getattr(self.config.trainer, "print_llm_ppl", False): |
|
rprint(f"AR PPL: {torch.exp(output.loss)}") |
|
else: |
|
if self.config.trainer.compile == 'max-autotune' and not is_xla_available: |
|
torch.compiler.cudagraph_mark_step_begin() |
|
|
|
logits = self.backbone(x, sigma, continuous_mode=continuous_mode, x_img_emb=x_img_emb, block_mask=block_mask, update_cache_slice=update_cache_slice, **kwargs) |
|
if self.config.trainer.force_bf16_eval: |
|
logits = logits.to(torch.bfloat16) |
|
|
|
if continuous_mode: |
|
assert self.parameterization == "ar" |
|
logits, logits_img = logits |
|
|
|
if self.config.trainer.ar_shift and not disable_ar_shift: |
|
|
|
|
|
logits = logits[:, :-1] |
|
xt = x[:, 1:] |
|
if orig_modality is not None and self.config.mode == 'eval': |
|
orig_modality = orig_modality[:, 1:] |
|
else: |
|
xt = x |
|
|
|
if self.config.trainer.low_precision_loss: |
|
logits = logits.to(self.dtype) |
|
if continuous_mode: |
|
logits_img = logits_img.to(self.dtype) |
|
|
|
if self.parameterization == "planner": |
|
return logits |
|
elif self.config.trainer.ar_llm_loss: |
|
assert not self.parameterization == "ar" |
|
model_output = self._subs_parameterization(logits, xt=xt, modality=orig_modality), logits |
|
if is_xla_available: shard_output(model_output[0], self.xla_mesh) |
|
if is_xla_available: shard_output(model_output[1], self.xla_mesh) |
|
return model_output if return_additional_loss else model_output[0] |
|
elif self.parameterization == "ar": |
|
if not getattr(self.config.trainer, "use_orig_unidisc_dit", False): |
|
logits = torch.where( |
|
torch.arange(logits.shape[-1], device=logits.device)[None, None, :] == self.mask_index, self.neg_infinity, logits |
|
) |
|
|
|
_modality = kwargs.get("modality") if batch is None else batch.get("modality") |
|
|
|
|
|
if getattr(self.config.model, "force_argmax_valid_indices", False) and _modality.shape[1] == (logits.shape[1] + 1): |
|
if not self.allow_slicing: |
|
logits = logits.clone() |
|
|
|
# Use _modality (which falls back to batch["modality"]) rather than kwargs.get("modality"),
# which can be None when the modality was provided via the batch instead of kwargs.
logits[..., self.text_vocab_size:] = torch.where(
    (_modality == 0)[..., 1:, None], torch.finfo(logits.dtype).min, logits[..., self.text_vocab_size:]
)
logits[..., :self.text_vocab_size] = torch.where(
    (_modality == 1)[..., 1:, None], torch.finfo(logits.dtype).min, logits[..., :self.text_vocab_size]
)
|
|
|
logits = logits.log_softmax(-1) |
|
|
|
if continuous_mode: |
|
return (logits, logits_img) |
|
elif self.parameterization == "subs": |
|
if return_logits: |
|
return logits |
|
model_output = self._subs_parameterization(logits, xt=xt, batch=batch, **kwargs) |
|
if is_xla_available: shard_output(model_output, self.xla_mesh) |
|
return model_output |
|
elif self.parameterization == "sedd": |
|
return self._sedd_parameterization(logits=logits, xt=x, sigma=sigma) |
|
elif self.parameterization == "d3pm": |
|
return self._d3pm_parameterization(logits=logits) |
|
|
|
return logits |
|
|
|
def compute_loss(self, batch, prefix, batch_idx=-1): |
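# Core objective: sample a timestep, corrupt the inputs, run the backbone, and compute either
# the weighted diffusion NLL (subs/d3pm/sedd) or the AR cross-entropy, with optional
# per-modality weighting and a continuous-image (MSE) loss in continuous mode.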
|
if not is_xla_available and ((self.current_run_fwd_bwd_pass == 0 and self.config.mode == 'train') or batch_idx == 0): |
|
self.visualize_samples(batch, batch_idx, split=prefix) |
|
if getattr(self.config.trainer, 'overfit_on_first_batch', False): |
|
if batch_idx <= 0: |
|
|
|
self.overfit_batch = batch.copy() |
|
else: |
|
batch = self.overfit_batch |
|
|
|
kwargs = self.get_cond_dict(batch) |
|
modality_mask = batch.get("modality_mask", None) |
|
(input_tokens, output_tokens, attention_mask) = self._maybe_sub_sample(batch["input_ids"], batch.get("attention_mask", None)) |
|
|
|
continuous_mode = self.config.trainer.image_mode == "continuous" |
|
joint_ar_nar_mask, modality = None, None |
|
if continuous_mode: |
|
assert 'modality' in batch |
|
x0, img_emb, attention_mask, modality = ( |
|
batch["input_ids"], |
|
batch["img_emb"], |
|
batch["attention_mask"], |
|
batch["modality"], |
|
) |
|
xt = x0 |
|
B, N_tot, C = img_emb.shape |
|
|
|
noise_scheduler = self.get_vae().scheduler |
|
noise = torch.randn_like(img_emb) |
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (B,), device=img_emb.device).long() |
|
img_timesteps = timesteps.unsqueeze(-1).expand(-1, N_tot).to(self.dtype) |
|
zero_timesteps = torch.zeros_like(img_timesteps) |
|
unet_conditioning = torch.where(modality == 1, img_timesteps, zero_timesteps) |
|
|
|
|
|
x_img_emb = noise_scheduler.add_noise(img_emb, noise, timesteps).to(self.dtype) |
|
|
|
if noise_scheduler.config.prediction_type == "epsilon":
    target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(img_emb, noise, timesteps)
else:
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
target = target.to(self.dtype)
|
else: |
|
unet_conditioning, xt, x0, x_img_emb, modality_mask = None, None, input_tokens, None, batch.get("modality_mask", None) |
|
if self.parameterization != "ar": |
|
t = self._sample_t(x0.shape[0], x0.device) |
|
if self.T > 0: |
|
t = (t * self.T).to(torch.int) |
|
t = t / self.T |
|
t += 1 / self.T |
|
|
|
if self.change_of_variables: |
|
unet_conditioning = t[:, None] |
|
f_T = torch.log1p(-torch.exp(-self.noise.sigma_max)) |
|
f_0 = torch.log1p(-torch.exp(-self.noise.sigma_min)) |
|
move_chance = torch.exp(f_0 + t * (f_T - f_0)) |
|
move_chance = move_chance[:, None] |
|
else: |
|
|
|
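# Standard discrete-diffusion schedule: total noise sigma and its rate dsigma give the masking
# probability move_chance = 1 - exp(-sigma).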
sigma, dsigma = self.noise(t) |
|
unet_conditioning = sigma[:, None] |
|
move_chance = 1 - torch.exp(-sigma[:, None]) |
|
|
|
xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, should_mask_txt, should_mask_img, move_indices = self.q_xt(x0, move_chance, return_ignore_batch_mask_for_metrics=True, batch=batch) |
|
if (self.config.model.flex_attention_img_masking_prob is not None or self.config.model.flex_attention_txt_masking_prob is not None) and self.backbone.training: |
|
assert xt.shape[1] == (self.config.model.img_length + self.config.model.txt_length) |
|
txt_batch_attn_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.model.flex_attention_txt_masking_prob |
|
img_batch_attn_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.model.flex_attention_img_masking_prob |
|
|
|
|
|
txt_batch_attn_dropout = txt_batch_attn_dropout & ~should_mask_txt.squeeze(-1) |
|
img_batch_attn_dropout = img_batch_attn_dropout & ~should_mask_img.squeeze(-1) |
|
kwargs['block_mask'] = get_block_mask(txt_batch_attn_dropout, img_batch_attn_dropout, self.config.model.txt_length, xt.shape[0], xt.shape[1], xt.device) |
|
|
|
|
|
ignore_batch_mask_for_metrics = ignore_batch_mask_for_metrics | (txt_batch_attn_dropout | img_batch_attn_dropout).unsqueeze(-1) |
|
|
|
if getattr(self.config.trainer, "interleaved_training_flex_attention", False): |
|
kwargs['block_mask'] = get_interleaved_block_mask(batch["sample_ids"], batch_size=xt.shape[0], seq_len=xt.shape[1], device=xt.device) |
|
kwargs['sample_ids'] = batch["sample_ids"] |
|
|
|
elif self.config.trainer.ar_inpainting: |
|
x0 = torch.cat([x0, x0], dim=1) |
|
kwargs['modality'] = torch.cat([kwargs['modality'], kwargs['modality']], dim=1) |
|
attention_mask = torch.cat([torch.zeros_like(attention_mask, dtype=attention_mask.dtype), torch.ones_like(attention_mask, dtype=attention_mask.dtype)], dim=1) |
|
modality_mask = torch.cat([modality_mask, modality_mask], dim=1) |
|
min_val, max_val = 0.0, 1.0 |
|
n = x0.shape[0] |
|
_eps_t = torch.rand(n, device=self.device) |
|
offset = torch.arange(n, device=self.device) / n |
|
_eps_t = (_eps_t / n + offset) % 1 |
|
t = (max_val - min_val) * _eps_t + min_val |
|
if getattr(self.config.eval, "ar_inpainting_force_val", None) is not None: |
|
t = torch.full_like(t, getattr(self.config.eval, "ar_inpainting_force_val"), dtype=t.dtype, device=t.device) |
|
move_indices = torch.rand(*x0.shape, device=x0.device) < t[:, None] |
|
move_indices[:, x0.shape[1] // 2:] = False |
|
x0 = torch.where(move_indices, self.mask_index, x0) |
|
xt = x0 |
|
else: |
|
xt = x0 |
|
if (self.training or getattr(self.config.trainer, "force_flip_ar_val", False)) and self.config.trainer.rand_ar_modality_dropout is not None: |
|
assert not is_xla_available |
|
xt = xt.clone() |
|
batch_modality_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.trainer.rand_ar_modality_dropout |
|
first_modality = batch["modality"][:, 0] |
|
first_modality_mask = batch["modality"] == first_modality[:, None] |
|
xt = torch.where(first_modality_mask & batch_modality_dropout[:, None], self.mask_index, xt) |
|
attention_mask = torch.where(first_modality_mask & batch_modality_dropout[:, None], False, attention_mask) |
|
true_logits = None |
|
model_output = self.forward( |
|
xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=x_img_emb, joint_ar_nar_mask=joint_ar_nar_mask, **kwargs |
|
) |
|
if isinstance(model_output, tuple): |
|
if continuous_mode: |
|
model_output, img_output = model_output |
|
B, _, C = img_output.shape |
|
|
|
x0 = x0[modality==0].reshape(B, -1) |
|
xt = xt[modality==0].reshape(B, -1) |
|
attention_mask = torch.ones_like(x0, dtype=torch.bool) |
|
img_output = img_output[modality==1].reshape(B, -1, C) |
|
target = target[modality==1].reshape(B, -1, C) |
|
else: |
|
model_output, true_logits = model_output |
|
|
|
to_dtype = self.dtype if self.config.trainer.low_precision_loss else torch.float32 |
|
model_output = model_output.to(to_dtype) |
|
if true_logits is not None: |
|
true_logits = true_logits.to(self.dtype) |
|
|
|
if continuous_mode: |
|
img_output = img_output.to(to_dtype) |
|
target = target.to(to_dtype) |
|
|
|
|
|
|
|
|
|
if self.config.trainer.ar_shift: |
|
x0 = x0[:, 1:] |
|
xt = xt[:, 1:] |
|
attention_mask = attention_mask[:, 1:] |
|
if modality_mask is not None: modality_mask = modality_mask[:, 1:] |
|
if modality is not None: modality = modality[:, 1:] |
|
|
|
if not self.is_compiled: |
|
utils.print_nans(model_output, "model_output") |
|
|
|
if self.parameterization == "sedd": |
|
return dsigma[:, None] * self._score_entropy(model_output, sigma[:, None], xt, x0) |
|
elif self.parameterization == "planner": |
|
return F.binary_cross_entropy_with_logits(model_output.squeeze(-1), move_indices.float()).mean() |
|
|
|
diffusion_loss = None |
|
if self.T > 0: |
|
diffusion_loss = self._d3pm_loss(model_output=model_output, xt=xt, x0=x0, t=t) |
|
if self.parameterization == "d3pm": |
|
reconstruction_loss = self._reconstruction_loss(x0) |
|
elif self.parameterization == "subs" or self.parameterization == "ar": |
|
reconstruction_loss = 0 |
|
|
|
|
|
if self.parameterization == "ar": |
|
if getattr(self.config.trainer, "use_orig_unidisc_dit", False): |
|
return self.shortcut_return(model_output, x0, attention_mask, prefix) |
|
else: |
|
log_p_theta = model_output.gather(-1, x0[:, :, None])[:, :, 0] |
|
else: |
|
|
|
log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) |
|
|
|
if self.change_of_variables or self.importance_sampling: |
|
return log_p_theta * torch.log1p(-torch.exp(-self.noise.sigma_min)) |
|
|
|
if self.parameterization == "ar" or getattr(self.config.trainer, "no_ce_weighting", False): |
|
std_weighting = 1 |
|
else: |
|
std_weighting = (dsigma / torch.expm1(sigma))[:, None] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
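# Continuous-time NELBO weighting: scale -log p_theta(x0 | xt) by dsigma / expm1(sigma)
# (or 1 for AR); optionally replaced by a softmin-SNR weighting below.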
loss = -log_p_theta * std_weighting |
|
if not (self.parameterization == "ar" or (self.config.trainer.ar_llm_loss and joint_ar_nar_mask is None) or getattr(self.config.trainer, "no_ce_weighting", False)): |
|
gamma = getattr(self.config.trainer, "softmin_snr", None) |
|
if gamma is not None: |
|
softmin_weighting = (dsigma / (torch.expm1(sigma) + (1 / gamma)))[:, None] |
|
loss = -log_p_theta * softmin_weighting |
|
|
|
if diffusion_loss is not None: |
|
assert self.T > 0 |
|
loss = diffusion_loss |
|
|
|
std_loss = -log_p_theta * std_weighting |
|
loss_dict = dict(std_loss=std_loss.detach(), extra_losses=dict()) |
|
|
|
if self.config.trainer.log_seperate_modal_losses: |
|
assert not continuous_mode |
|
loss_dict.update( |
|
dict( |
|
std_txt_loss=(std_loss.detach() * modality_mask[..., 0] * attention_mask), |
|
std_img_loss=(std_loss.detach() * modality_mask[..., 1] * attention_mask) |
|
) |
|
) |
|
|
|
if getattr(self.config.trainer, "mask_entire_modality", None) is not None and self.backbone.training and not self.config.parameterization == "ar": |
|
loss_dict['batch_ignore_loss'] = ignore_batch_mask_for_metrics.squeeze(-1) |
|
|
|
if joint_ar_nar_mask is not None: |
|
if "batch_ignore_loss" in loss_dict: |
|
loss_dict["batch_ignore_loss"] = loss_dict["batch_ignore_loss"] | joint_ar_nar_mask |
|
else: |
|
loss_dict["batch_ignore_loss"] = joint_ar_nar_mask |
|
|
|
if (self.config.trainer.multimodal_batches or (self.config.trainer.text_loss_weight is not None and self.config.trainer.img_loss_weight is not None)) and not continuous_mode: |
|
txt_mask = modality_mask[..., 0] & attention_mask |
|
img_mask = modality_mask[..., 1] & attention_mask |
|
txt_count = txt_mask.sum() |
|
img_count = img_mask.sum() |
|
total_count = txt_count + img_count |
|
txt_frac = txt_count / total_count |
|
img_frac = img_count / total_count |
|
loss_dict["extra_losses"]["trainer/img_frac"] = img_frac |
|
loss_dict["extra_losses"]["trainer/txt_frac"] = txt_frac |
|
loss_dict["extra_losses"]["trainer/attention_mask_valid_frac"] = attention_mask.sum() / attention_mask.numel() |
|
if "batch_ignore_loss" in loss_dict: |
|
loss_dict["extra_losses"]["trainer/ignore_batch_metrics_frac"] = loss_dict["batch_ignore_loss"].sum() / loss_dict["batch_ignore_loss"].numel() |
|
|
|
if joint_ar_nar_mask is not None: |
|
pass |
|
elif self.config.trainer.text_loss_weight is not None and self.config.trainer.img_loss_weight is not None: |
|
assert not continuous_mode |
|
loss = loss * attention_mask |
|
txt_loss = ( |
|
loss[txt_mask].sum() / txt_count |
|
) * txt_frac * self.config.trainer.text_loss_weight |
|
img_loss = ( |
|
loss[img_mask].sum() / img_count |
|
) * img_frac * self.config.trainer.img_loss_weight |
|
|
|
if getattr(self.config.trainer, "set_max_txt_loss_ratio", None) is not None and not (torch.isnan(img_loss).any() or torch.isnan(txt_loss).any()): |
|
max_txt_loss = getattr(self.config.trainer, "set_max_txt_loss_ratio", 1.5) * img_loss.detach() |
|
scale = torch.minimum(torch.tensor(1.0, device=txt_loss.device), max_txt_loss / (txt_loss.detach() + 1e-8)) |
|
txt_loss = txt_loss * scale |
|
|
|
txt_loss = torch.nan_to_num(txt_loss, nan=0.0) |
|
img_loss = torch.nan_to_num(img_loss, nan=0.0) |
|
|
|
if getattr(self.config.trainer, "force_remove_img_tokens", False): |
|
img_loss = torch.tensor(0, device=loss.device, dtype=loss.dtype) |
|
|
|
loss = txt_loss + img_loss |
|
loss_dict.update(dict(txt_loss=txt_loss.clone().detach(), img_loss=img_loss.clone().detach())) |
|
|
|
elif continuous_mode: |
|
img_loss = F.mse_loss(img_output, target) |
|
|
|
if attention_mask[:, self.static_txt_sl].numel() == 0: |
|
|
|
txt_loss = (loss[:, self.static_txt_sl] * attention_mask[:, self.static_txt_sl]).sum() |
|
else: |
|
txt_loss = (loss[:, self.static_txt_sl] * attention_mask[:, self.static_txt_sl]).sum() / attention_mask[:, self.static_txt_sl].sum() |
|
loss = txt_loss + img_loss * self.config.trainer.image_loss_weight |
|
loss_dict.update(dict(img_loss=img_loss.clone().detach(), txt_loss=txt_loss.clone().detach())) |
|
else: |
|
_attention_mask = torch.ones_like(attention_mask) if getattr(self.config.trainer, "force_full_attention_mask_loss_only", False) else attention_mask |
|
loss = (loss * _attention_mask).sum() / _attention_mask.sum() |
|
loss = torch.nan_to_num(loss, nan=0.0) |
|
|
|
ar_loss = None |
|
if self.config.trainer.ar_llm_loss: |
|
assert not continuous_mode |
|
valid_loss = xt == self.mask_index |
|
_labels = x0.clone() |
|
_labels = torch.where(valid_loss, _labels, -1) |
|
_labels = torch.where(~attention_mask.to(torch.bool), -1, _labels) |
|
|
|
_logits = true_logits |
|
_logits[:, :, self.mask_index] += self.neg_infinity |
|
|
|
if getattr(self.config.model, "force_argmax_valid_indices", False): |
|
assert not self.config.trainer.multimodal_batches |
|
_logits[:, self.static_txt_sl, self.text_vocab_size:] = torch.finfo(_logits.dtype).min |
|
_logits[:, self.static_img_sl, : self.text_vocab_size] = torch.finfo(_logits.dtype).min |
|
|
|
_logits = _logits.contiguous().view(-1, _logits.shape[-1]) |
|
_labels = _labels.contiguous().view(-1) |
|
|
|
if self.config.trainer.ar_print_loss:
    _labels = _labels.to(_logits.device)
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    ce_loss = loss_fct(_logits, _labels)
    ce_loss = ce_loss.mean(dim=-1)
    if not hasattr(self, 'histogram'):
        self.histogram = {}
|
|
|
update_histogram(self.histogram, t, ce_loss) |
|
rprint(f"ELM loss: move: {move_chance}, t:{t}, {ce_loss}") |
|
|
|
loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none' if joint_ar_nar_mask is not None else 'mean') |
|
ce_loss = loss_fct(_logits, _labels) |
|
loss_dict["extra_losses"]["trainer/ce_loss"] = ce_loss |
|
ar_loss = ce_loss |
|
|
|
if joint_ar_nar_mask is not None: |
|
__true_logits = true_logits.clone() |
|
__true_logits = torch.where(torch.arange(true_logits.shape[-1], device=true_logits.device)[None, None, :] == self.mask_index, self.neg_infinity, __true_logits) |
|
log_softmax = __true_logits.log_softmax(-1) |
|
ar_loss = -log_softmax.gather(-1, x0[:, :, None])[:, :, 0] |
|
|
|
assert ar_loss is not None |
|
assert ar_loss.ndim == 2 |
|
assert loss.ndim == 2 |
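# Mix the AR and NAR objectives: weight each by the fraction of samples assigned to that mode,
# then combine per-sample so each example only contributes its own objective.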
|
ar_loss_weight = joint_ar_nar_mask.sum(dim=0) / joint_ar_nar_mask.shape[0] |
|
nar_loss_weight = 1 - ar_loss_weight |
|
loss_dict["extra_losses"]["trainer/ar_loss_weight"] = ar_loss_weight.detach().float() |
|
loss_dict["extra_losses"]["trainer/nar_loss_weight"] = nar_loss_weight.detach().float() |
|
loss_dict["extra_losses"]["trainer/ce_loss"] = ar_loss.mean().detach().float() |
|
ar_loss = (ar_loss * ar_loss_weight) * attention_mask |
|
nar_loss = (loss * nar_loss_weight) * attention_mask |
|
valid_count = attention_mask.sum() |
|
if not is_xla_available: |
|
ar_valid_count = attention_mask[joint_ar_nar_mask].sum() |
|
nar_valid_count = attention_mask[~joint_ar_nar_mask].sum() |
|
loss_dict["extra_losses"]["trainer/ar_loss"] = (ar_loss[joint_ar_nar_mask].sum() / ar_valid_count).detach().float() |
|
loss_dict["extra_losses"]["trainer/nar_loss"] = (loss[~joint_ar_nar_mask].sum() / nar_valid_count).detach().float() |
|
loss_dict["extra_losses"]["trainer/ar_ppl"] = torch.exp(loss_dict["extra_losses"]["trainer/ar_loss"]).detach().float() |
|
loss_dict["extra_losses"]["trainer/nar_ppl"] = torch.exp(loss_dict["extra_losses"]["trainer/nar_loss"]).detach().float() |
|
weighted_z_loss = 0.0  # defined here so the expression below is well-formed; no auxiliary z-loss is computed in this path
loss = (torch.where(joint_ar_nar_mask[:, None], ar_loss, nar_loss).sum() / valid_count) + weighted_z_loss
|
elif ar_loss is not None: |
|
loss = ar_loss |
|
|
|
loss_dict = dict(loss=loss, **loss_dict) |
|
std_loss = loss_dict.get("std_loss", 0) |
|
std_nlls = std_loss * attention_mask |
|
|
|
if "batch_ignore_loss" in loss_dict: |
|
attention_mask = torch.where(loss_dict['batch_ignore_loss'][:, None].repeat(1, attention_mask.shape[-1]), torch.full_like(attention_mask, False), attention_mask) |
|
|
|
losses = Loss( |
|
loss=loss_dict["loss"], |
|
img_loss=loss_dict.get("img_loss", 0), |
|
txt_loss=loss_dict.get("txt_loss", 0), |
|
nlls=std_nlls, |
|
txt_nlls=loss_dict.get("std_txt_loss", 0), |
|
img_nlls=loss_dict.get("std_img_loss", 0), |
|
token_mask=attention_mask, |
|
modality_mask=modality_mask, |
|
extra_losses=loss_dict.get("extra_losses", None), |
|
) |
|
|
|
if getattr(self.config.trainer, "disable_torchmetrics", False): |
|
raise NotImplementedError("Torchmetrics disabled") |
|
|
|
elif prefix == "train": |
|
return losses |
|
elif prefix == "val": |
|
self.valid_metrics.update(losses.nlls, losses.token_mask) |
|
if hasattr(self, "valid_txt_metrics"): |
|
self.valid_txt_metrics.update(losses.txt_nlls, losses.modality_mask[..., 0] & losses.token_mask) |
|
self.valid_img_metrics.update(losses.img_nlls, losses.modality_mask[..., 1] & losses.token_mask) |
|
|
|
elif prefix == "test": |
|
self.test_metrics.update(losses.nlls, losses.token_mask) |
|
metrics = self.test_metrics |
|
self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True) |
|
else: |
|
raise ValueError(f"Invalid prefix: {prefix}") |
|
|
|
@torch.no_grad() |
|
def zero_shot_eval(self): |
|
dataloader = self.validation_dataloader |
|
total_batches = len(dataloader) |
|
rprint(f"Zero shot eval with {total_batches} batches with limit_val_batches: {self.config.trainer.limit_val_batches}") |
|
for idx, batch in tqdm(enumerate(dataloader), total=total_batches, desc="Zero shot eval validation steps", disable=not is_main_process()): |
|
if self.config.trainer.limit_val_batches is not None and idx >= self.config.trainer.limit_val_batches: |
|
break |
|
self.zero_shot_eval_step(batch, idx) |
|
|
|
self.zero_shot_eval_epoch_end() |
|
|
|
def validate(self, state: TrainingState): |
|
self.on_validation_epoch_start() |
|
|
|
if getattr(self.config.eval, "compute_val_metrics_standalone", False) and getattr(self.config.eval, "bypass_normal_validation", False): |
|
batch = next(iter(self.validation_dataloader)) |
|
self.on_validation_epoch_end(example_batch=batch) |
|
self.on_validation_epoch_cleanup() |
|
return |
|
|
|
total_len = 10 if self.config.data.iterable or self.config.data.webdataset_indexed else len(self.validation_dataloader) |
|
dprint(f"Validation batches: {total_len}") |
|
|
|
total_batches = ( |
|
self.config.trainer.limit_val_batches |
|
if (self.config.trainer.limit_val_batches is not None and self.fid_eval is False) |
|
else total_len |
|
) |
|
if getattr(self.config.eval, 'pplx_full_dataset', False): |
|
rprint("[INFO] PPLX full dataset eval, setting total_batches to total_len") |
|
total_batches = total_len |
|
elif self.config.eval.max_num_fid_batches_per_device is not None and self.fid_eval: |
|
total_batches = min(total_len, self.config.eval.max_num_fid_batches_per_device) |
|
|
|
_dataloader = self.train_dataloader if self.config.eval.val_with_train_data else self.validation_dataloader |
|
rprint(f"Validating with {total_batches} batches on {self.world_size} GPUs with batch size {self.config.loader.eval_batch_size}") |
|
for idx, batch in tqdm(enumerate(_dataloader), total=total_batches, desc="Validation steps", disable=not is_main_process()): |
|
if self.config.trainer.limit_val_batches is not None and idx >= total_batches: |
|
break |
|
self.validation_step(batch, idx) |
|
|
|
if getattr(self.config.eval, "eval_large_batch", None) is not None: |
|
assert isinstance(batch, TensorDict) |
|
dataloader_iter = iter(_dataloader) |
|
large_batch = [next(dataloader_iter, None) for _ in range(getattr(self.config.eval, "eval_large_batch", None))] |
|
large_batch = [b for b in large_batch if b is not None] |
|
large_batch = torch.stack(large_batch, dim=0) |
|
batch = large_batch |
|
gprint(f"Large batch shape: {batch.shape}") |
|
else: |
|
batch = next(iter(_dataloader)) |
|
|
|
if self.config.eval.visualize_data_only: |
|
return |
|
|
|
if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False): |
|
self.mauve_store_references(_dataloader) |
|
|
|
if self.config.mode == "eval": |
|
gprint(f"Batch shape: {batch['input_ids'].shape}") |
|
|
|
self.on_validation_epoch_end(example_batch=batch) |
|
self.on_validation_epoch_cleanup() |
    @cached_property
    def global_batch_size(self):
        """Batch size for a single step over all GPUs."""
        return self.step_batch_size * (1 if (self.config.trainer.xla_spmd and is_xla_available) else self.world_size)

    @cached_property
    def step_batch_size(self):
        """Batch size for a single step for a single GPU."""
        return self.config.loader.batch_size * self.config.trainer.accumulate_grad_batches
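
    # Example (hypothetical values): with loader.batch_size=32, accumulate_grad_batches=4 and
    # 8 data-parallel GPUs (no XLA SPMD), step_batch_size = 32 * 4 = 128 and
    # global_batch_size = 128 * 8 = 1024 samples per optimizer step.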

    @cached_property
    def world_size(self):
        """Number of GPUs over all nodes."""
        return get_world_size()

    @cached_property
    def num_tokens_per_sample(self):
        """Number of tokens per sample."""
        return self.config.model.length

    @cached_property
    def gradient_accumulation_steps(self):
        """Number of gradient accumulation steps."""
        return self.config.trainer.accumulate_grad_batches

    @cached_property
    def static_txt_sl(self):
        return slice(None, self.config.model.txt_length)

    @cached_property
    def static_img_sl(self):
        return slice(-self.config.model.img_length, None)

    def img_txt_pair_batch_mask(self, batch=None):
        return batch["modality_mask"][..., 1].sum(dim=-1) > 0

    def txt_sl(self, batch=None):
        return batch["modality_mask"][..., 0]

    def img_sl(self, batch=None):
        return batch["modality_mask"][..., 1]

    @cached_property
    def is_compiled(self):
        return is_xla_available or self.config.trainer.compile

    @property
    def allow_slicing(self):
        return not is_xla_available and not self.backbone.training

    @property
    def training(self):
        return self.backbone.training

    def get_step_metrics(self):
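        """Global step/sample/token counters for logging.

        The `effective_*` values are scaled by 0.5 for the "subs" parameterization,
        presumably because only the masked positions (half the tokens in expectation)
        are supervised at each step.
        """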
        return {
            "trainer/global_step": self.global_step,
            "global_samples": self.global_step * self.global_batch_size,
            "train_metrics/global_tokens": self.global_step * self.global_batch_size * self.config.model.length,
            "effective_global_tokens": self.global_step * self.global_batch_size * self.config.model.length * (0.5 if self.config.parameterization == "subs" else 1.0),
            "effective_global_step": int(self.global_step * (0.5 if self.config.parameterization == "subs" else 1.0)),
        }

    def train(self):
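        """Main training loop.

        Iterates over epochs and dataloader batches, runs the forward/backward pass under
        gradient accumulation, steps the optimizer/scheduler and (optionally) the EMA,
        and handles logging, profiling and checkpointing until trainer.max_steps is reached.
        """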
        tr = self.config.trainer
        total_batch_size = self.global_batch_size
        initial_global_step = self.global_step
        true_step = 0
        first_epoch = 0
        self.current_run_global_step = 0
        self.current_run_fwd_bwd_pass = 0
        rprint(f"Started at step {self.accelerator.step}")
        if self.non_embedding_params < 1e9:
            with try_except(write_error_to_file=True, clear_cuda_cache=True):
                self.print_hashes()

        if is_torch_cuda_available():
            dprint(f"Gathering step from {self.world_size} ranks")
            starting_steps = gather_object([self.accelerator.step])
            rprint(f"Starting steps: {starting_steps}")
            if not all(x > 0 for x in starting_steps):
                rprint(f"Not all ranks have a step > 0; setting this rank's step to rank 0's: {starting_steps[0]}")
                self.accelerator.step = starting_steps[0]

        if is_xla_available:
            import torch_xla.core.xla_model as xm
            import torch_xla.debug.profiler as xp
            assert (self.config.trainer.accumulate_grad_batches == 1) or getattr(self.config.trainer, "allow_accum_grad_batches_xla", False), "Accumulate grad batches must be 1 for XLA"

        rprint(f"***** Starting training at global step: {self.global_step} *****")
        rprint(f" Instantaneous batch size per device = {self.config.loader.batch_size}")
        rprint(f" Gradient Accumulation steps = {tr.accumulate_grad_batches}")
        rprint(f" Num GPUs = {tr.devices}")
        rprint(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        rprint(f" Total optimization steps = {tr.max_steps}")
        rprint(f" Reported Global Batch Size: {self.global_batch_size}, Reported Step Batch Size: {self.step_batch_size}, Reported World Size: {self.world_size}")

        if not self.config.data.iterable and not self.config.data.webdataset_indexed and is_torch_cuda_available():
            num_epoch_steps = len(self.train_dataloader)
            rprint(f" Num examples = {len(self.train_dataloader.dataset)}")
            rprint(f" Num batches each epoch = {len(self.train_dataloader)}")
            rprint(f"Train Dataloader Size on single GPU: {num_epoch_steps}")
            if len(self.train_dataloader.dataset) < total_batch_size:
                rprint("The training dataloader is smaller than the total batch size. This may lead to unexpected behaviour.")
        else:
            num_epoch_steps = 10000

        if self.config.trainer.pytorch_profile:
            profiler = Profiler(
                output_dir=self.config.output_dir, warmup_steps=tr.profiler_warmup_steps, active_steps=tr.profiler_active_steps, record_memory=True
            )

        if self.config.trainer.viz_images_only:
            return self.viz_images_from_dataloader()

        progress_bar = tqdm(range(0, tr.max_steps), initial=initial_global_step, desc="Steps", disable=not is_local_main_process(), leave=False, smoothing=0.15)

        global_step_metrics = defaultdict(float)
        global_extra_wandb_metrics = dict()
        accumulate_steps = 0
        first_start_time = time.time()
        self.on_train_start()

        rprint(f"Training for {tr.num_epochs} epochs...")
        last_end_step_time = start_timing(f"Dataloading accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")

        for epoch in range(first_epoch, tr.num_epochs):
            rprint(f"Starting epoch {epoch}...")
            for step, batch in enumerate(self.train_dataloader):
                ddprint(f"Data Step: {step}")
                if self.config.trainer.iterate_dataloader_only:
                    rprint(f"Iterating dataloader only: {step}")

                    if (batch["attention_mask"] == 0).all(dim=-1).any():
                        breakpoint()
                    batch = self.update_batch(batch)
                    if (batch["sample_ids"] == -1).all(dim=-1).any():
                        breakpoint()
                    continue

                elif getattr(self.config.trainer, "iterate_dataloader_n_dataloader_batches", None) is not None and step <= self.config.trainer.iterate_dataloader_n_dataloader_batches:
                    self.current_run_fwd_bwd_pass += 1
                    if self.current_run_fwd_bwd_pass % self.config.trainer.accumulate_grad_batches == 0:
                        self.global_step += 1
                        self.current_run_global_step += 1
                    ddprint(f"Iterating dataloader only for {self.config.trainer.iterate_dataloader_n_dataloader_batches} dataloader batches. At step {self.global_step=}, {self.current_run_global_step=}, {self.current_run_fwd_bwd_pass=}")
                    continue

                if self.config.trainer.tpu_force_mark_step: xm.mark_step()
                if self.config.trainer.sync_dataloader_timing: synchronize_device()
                global_step_metrics["dataloading_time"] += end_timing(last_end_step_time)

                if self.config.trainer.nvtx_profile and self.is_compiled and step == 4:
                    torch.cuda.cudart().cudaProfilerStart()

                if self.current_run_global_step == 1 and is_xla_available:
                    gprint(f"First start time: {time.time() - first_start_time}")

                if getattr(self.config.data, "force_dummy_tensordict", False):
                    gprint(self.global_step, self.current_run_global_step, true_step, batch["idx"].tolist(), batch["dataset_idx"].tolist())

                if getattr(self.config.trainer, "assert_at_n_steps", None) is not None and self.global_step == self.config.trainer.assert_at_n_steps:
                    gprint(batch["img_input_ids"].min(), batch["img_input_ids"].max(), batch["txt_input_ids"].min(), batch["txt_input_ids"].max())

                if batch is None:
                    rprint(f"Batch is None at step {step}")
                    continue

                if self.config.trainer.tpu_force_mark_step: xm.mark_step()
                ddprint(f"After Data Step 2: {step}")
                with nullcontext() if is_xla_available else self.accelerator.accumulate(self.backbone):
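                    # Forward/backward run under the accumulate() context so gradients only sync
                    # every `accumulate_grad_batches` micro-batches; on XLA the context is bypassed
                    # (accumulation is asserted to be 1 above unless explicitly allowed).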
                    ddprint(f"Before forward pass for global_step: {self.global_step}")
                    start_forward_time = start_timing(f"Forward Pass accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
                    global_step_metrics["examples_seen_per_gpu"] += len(next(iter(batch.values())))
                    state: TrainingState = TrainingState(
                        epoch_step=step,
                        num_epoch_steps=num_epoch_steps,
                        global_step=self.global_step,
                        epoch=epoch,
                        true_step=true_step,
                        current_run_global_step=self.current_run_global_step,
                    )

                    if self.accelerator.sync_gradients and is_xla_available is False:
                        self.cb_handler.on_train_step_start(state=state, unit=None)

                    if self.config.trainer.tpu_force_mark_step: xm.mark_step()

                    ddprint(f"Before Fwd: {step}")
                    with xp.StepTrace('Forward', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
                        losses = self.training_step(batch, step)

                    ddprint(f"After Fwd: {step}")
                    global_step_metrics["forward_pass_time"] += end_timing(start_forward_time)
                    true_step += 1
                    evaluate_extra_log_data = lambda: dict()

                    if self.config.trainer.tpu_force_mark_step: xm.mark_step()
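
                    # `training_step` may return a plain dict of losses (entries prefixed with
                    # "metric_" are logged only; the rest are summed into the loss), a `Loss`
                    # object with per-modality nlls/masks, or a bare loss tensor.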
                    if isinstance(losses, dict):
                        for k, v in losses.items():
                            if isinstance(v, torch.Tensor):
                                global_step_metrics[k.removeprefix("metric_")] += v.detach().cpu().item()
                            else:
                                global_extra_wandb_metrics[k.removeprefix("metric_")] = v
                        losses = dict(
                            filter(lambda item: not item[0].startswith("metric_"), losses.items())
                        )
                        loss = sum(losses.values())
                    elif isinstance(losses, Loss):
                        loss = losses.loss
                        metrics = self.train_metrics(losses.nlls, losses.token_mask)
                        txt_metrics = img_metrics = None
                        if hasattr(self, "txt_metrics") and losses.modality_mask is not None:
                            txt_metrics = self.txt_metrics(losses.txt_nlls, losses.modality_mask[..., 0] & losses.token_mask)
                        if hasattr(self, "img_metrics") and losses.modality_mask is not None:
                            img_metrics = self.img_metrics(losses.img_nlls, losses.modality_mask[..., 1] & losses.token_mask)

                        extra_losses_dict = losses.extra_losses
                        extra_losses_dict = extra_losses_dict if extra_losses_dict is not None else dict()
                        if self.config.trainer.tpu_force_mark_step: xm.mark_step()

                        def evaluate_extra_log_data():
                            # Guard on the computed values (not just the attribute) so we never
                            # reference txt/img metrics that were skipped above.
                            if txt_metrics is not None and img_metrics is not None:
                                return {
                                    **{f"train/txt_{k.split('/')[-1]}": v for k, v in replace_nan_dict(txt_metrics).items()},
                                    **{f"train/img_{k.split('/')[-1]}": v for k, v in replace_nan_dict(img_metrics).items()},
                                }
                            else:
                                return {}

                        ddprint(f"Before loss: {step}")
                        incremental_dict_update(global_extra_wandb_metrics, {
                            "trainer/loss": loss,
                            "trainer/img_loss": losses.img_loss,
                            "trainer/txt_loss": losses.txt_loss,
                            "global_samples": self.global_step * self.global_batch_size,
                            "train_metrics/global_tokens": self.global_step * self.global_batch_size * self.config.model.length,
                            "effective_global_tokens": self.global_step * self.global_batch_size * self.config.model.length * (0.5 if self.config.parameterization == "subs" else 1.0),
                            "effective_global_step": int(self.global_step * (0.5 if self.config.parameterization == "subs" else 1.0)),
                            **metrics,
                            **extra_losses_dict,
                        })
                        if self.config.trainer.tpu_force_mark_step: xm.mark_step()
                    else:
                        loss = losses

                    if is_torch_cuda_available():
                        global_step_metrics["loss"] = loss.detach().cpu().item()

                    ddprint(f"Before backward pass for global_step: {self.global_step}")

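                    # The backward pass is skipped when the loss is non-finite; on XLA the
                    # isfinite check is bypassed, presumably to avoid forcing a host sync.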
                    if tr.backward_pass and (is_xla_available or torch.isfinite(loss).all()):
                        start_backward_time = start_timing(f"Backward Pass accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
                        if self.accelerator.sync_gradients:
                            start_sync_time = start_timing(f"Gradient Sync global_step:{self.global_step}")
                            if getattr(self.config.trainer, "sync_timing", False):
                                sync_times(self.device)

                        if self.config.trainer.tpu_force_mark_step: xm.mark_step()

                        with xp.StepTrace('Backward', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
                            ddprint(f"Before accelerator.backward for global_step: {self.global_step}")
                            self.accelerator.backward(loss)
                            ddprint(f"After accelerator.backward for global_step: {self.global_step}")

                        with xp.StepTrace('After Backward + Clip', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
                            if self.accelerator.sync_gradients:
                                ddprint(f"Before after.backward for global_step: {self.global_step}")
                                self.after_backward(state)
                                if tr.gradient_clip_val is not None:
                                    ddprint(f"Before self.accelerator.clip_grad_norm_ for global_step: {self.global_step}")
                                    total_grad_norm = self.accelerator.clip_grad_norm_(self.backbone.parameters(), tr.gradient_clip_val)
                                    ddprint(f"After self.accelerator.clip_grad_norm_ for global_step: {self.global_step}")

                        with xp.StepTrace('Optimizer + Scheduler Step', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
                            ddprint(f"Before optimizer step for global_step: {self.global_step}, {step}")
                            if is_xla_available and False:
                                xm.optimizer_step(self.optimizer)
                            else:
                                self.optimizer.step()
                            ddprint(f"After optimizer step for global_step: {self.global_step}, {step}")
                            self.lr_scheduler.step()
                            ddprint(f"After lr_scheduler step for global_step: {self.global_step}, {step}")

                            zero_grad_kwargs = dict()
                            if "apex" not in self.config.trainer.optimizer_cls:
                                zero_grad_kwargs["set_to_none"] = tr.set_grads_to_none

                            ddprint(f"Before zero_grad for global_step: {self.global_step}, {step}")
                            self.optimizer.zero_grad(**zero_grad_kwargs)
                            ddprint(f"Zeroed gradients for global_step: {self.global_step}, {step}")

                            if self.accelerator.sync_gradients:
                                if self.ema is not None:
                                    if self.config.trainer.use_custom_ema:
                                        ema_update(self.unwrap_model(self.ema), self.unwrap_model(self.backbone), self.config.trainer.ema)
                                    else:
                                        self.ema.step(self.get_params())
                                global_step_metrics["gradient_sync_time"] += end_timing(start_sync_time)

                        global_step_metrics["backward_pass_time"] += end_timing(start_backward_time)
                    else:
                        if not torch.isfinite(loss).all(): gprint(f"Loss is not finite: {loss}")
                        gprint("Skipping backward pass!")

                    accumulate_steps += 1
                    self.current_run_fwd_bwd_pass += 1

                ddprint(f"Syncing gradients for global_step: {self.global_step}. Should sync: {self.accelerator.sync_gradients}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
                if self.accelerator.sync_gradients:
                    start_gradient_sync_time = start_timing(f"On Sync Gradients global_step:{self.global_step}, {step}")

                    ddprint(f"Before on_train_step_end for global_step: {self.global_step}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
                    state.batch = batch
                    del loss, losses, batch
                    gradient_sync_time_after_train_step_end_time = start_timing(f"On Sync Gradients global_step:{self.global_step}, {step}")
                    self.on_train_step_end(state)
                    ddprint(f"After on_train_step_end for global_step: {self.global_step}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
                    global_step_metrics["gradient_sync_time_after_train_step_end"] += end_timing(gradient_sync_time_after_train_step_end_time)

                    if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()

                    if self.config.trainer.profile_memory and self.global_step + 1 >= tr.max_steps:
                        rprint("Finished profiling memory...")
                        break

                    if self.config.trainer.pytorch_profile and profiler.step(self.global_step):
                        rprint(f"Profiling finished at step: {self.global_step}")
                        break

                    if getattr(self.config.trainer, "throw_failure_for_testing", False) and self.current_run_global_step == 5:
                        raise RuntimeError("Test failure")

                    if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()

                    progress_bar.update(1)
                    self.global_step += 1
                    self.current_run_global_step += 1
                    global_step_metrics["gradient_sync_time"] += end_timing(start_gradient_sync_time)

                    logs = {
                        "examples_seen": self.global_step * total_batch_size,
                        "trainer/global_step": self.global_step,
                        **global_step_metrics,
                        **{f"lr_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_last_lr())},
                        **global_extra_wandb_metrics,
                    }

                    if is_torch_cuda_available():
                        logs["gpu_max_mem_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
                        logs["gpu_cur_mem_reserved_gb"] = torch.cuda.memory_reserved() / (1024**3)
                        logs["gpu_max_mem_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
                        logs["gpu_cur_mem_allocated_gb"] = torch.cuda.memory_allocated() / (1024**3)

                    if is_xla_available:
                        if self.global_step % getattr(self.config.trainer, "log_every_n_steps", 1) == 0:
                            xm.add_step_closure(update_logs, args=(logs, evaluate_extra_log_data), run_async=False)
                            del logs
                            global_extra_wandb_metrics = dict()
                        if self.config.trainer.tpu_force_mark_step: xm.mark_step()
                    else:
                        logs.update(evaluate_extra_log_data())
                        progress_bar.set_postfix(**logs)
                        log(logs)
                        global_extra_wandb_metrics = dict()

                    if getattr(self.config.trainer, "sync_timing", False):
                        global_step_metrics = {f"rank_{get_rank()}/{k}": v for k, v in global_step_metrics.items()}
                        all_step_metrics = self.accelerator.gather_for_metrics([global_step_metrics], use_gather_object=True)
                        merged_metrics = {k: v for d in all_step_metrics for k, v in d.items()}
                        log(merged_metrics)

                    if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()

                    global_step_metrics = defaultdict(float)
                    accumulate_steps = 0

                    if self.global_step >= tr.max_steps:
                        break

                    ddprint(f"After logging for step v3: {self.global_step}, {step}")

                    if getattr(self.config.trainer, "assert_at_n_steps", None) is not None and self.current_run_global_step >= getattr(self.config.trainer, "assert_at_n_steps", None):
                        raise RuntimeError(f"Assertion failed at step {self.current_run_global_step}")

                    ddprint(f"After logging for step v4: {self.global_step}, {step}")

                    if is_xla_available and self.config.trainer.tpu_profile and (self.global_step == 0 or self.global_step % 50 == 0) and is_main_process():
                        import torch_xla.debug.metrics as met
                        rprint(met.metrics_report())
                        met.clear_all()

                    if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
                    ddprint(f"Finished sync_gradients: {self.global_step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")

                ddprint(f"Finished step: {self.global_step},{step},{self.accelerator.step},{self.accelerator.gradient_accumulation_steps},{self.accelerator.gradient_state.__repr__()}")
                if self.config.trainer.sync_dataloader_timing: synchronize_device()
                last_end_step_time = start_timing(f"Dataloading #{true_step + 1}")

            if self.global_step >= tr.max_steps:
                break

            dprint(f"Finished epoch: {epoch}")

        rprint("Training finished.")
        barrier()

        if tr.profile_memory:
            print_memory(verbose=True)
            save_memory_profile(self.config.output_dir / "profile")

        if tr.pytorch_profile:
            profiler.finish()
        elif tr.nvtx_profile:
            torch.cuda.cudart().cudaProfilerStop()
        elif self.global_step > 100 or tr.skip_early_checkpointing is False:
            self.checkpoint(state)

        barrier()