# unidisc / model.py
# Author: aswerdlow
# Commit: "Initial commit" (131da64)
import math
import random
import types
import time
from collections import defaultdict
from contextlib import nullcontext
from functools import cached_property, partial
from contextlib import ExitStack
from numpy import mask_indices
from unidisc.utils.tensor_utils import get_contiguous_blocks, get_contiguous_blocks_per_sample, get_interleaved_indices
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate.utils import gather, gather_object
from einops import rearrange
from tensordict import TensorDict
from torch import Tensor, nn
from tqdm.auto import tqdm
import model_eval
import model_setup
import model_utils
import utils
from decoupled_utils import (Profiler, barrier, dprint, get_rank, get_world_size, gprint,
is_local_main_process, is_main_process,
is_torch_cuda_available, is_torch_xla_available,
print_memory, rprint, save_memory_profile,
synchronize_device, try_except, use_dist)
from unidisc.tokenizers.image_tokenizers import (decode_latents, get_image_batch,
get_vae, vae_encode_image)
from unidisc.utils.cuda_utils import sync_times
from unidisc.utils.xla_utils import shard_output
from model_utils import (Loss, ddprint, ema_update, empty_device_cache, get_chameleon_txt_indices, get_interleaved_block_mask, log,
replace_nan_dict, update_histogram, update_logs, get_block_mask)
from unidisc.utils.trainer_utils import TrainingState, incremental_dict_update, linear_warmup
# Probe for torch_xla once at import time. The XLA-only imports are guarded so
# that CPU/GPU environments never need torch_xla installed.
is_xla_available = is_torch_xla_available()
if is_xla_available:
    import torch_xla  # noqa: F401 -- not referenced in the code shown here; kept as in the original
    from torch_xla.distributed.spmd import XLAShardedTensor
def maybe_unwrap(t: torch.Tensor) -> torch.Tensor:
    """Return the global (unsharded) tensor if ``t`` is an XLA sharded tensor.

    Any other value is returned unchanged.

    ``XLAShardedTensor`` is only imported at module level when torch_xla is
    available. It is therefore looked up lazily here; referencing it directly
    raised ``NameError`` on non-XLA hosts.
    """
    sharded_tensor_cls = globals().get("XLAShardedTensor")
    if sharded_tensor_cls is not None and isinstance(t, sharded_tensor_cls):
        return t.global_tensor
    return t
class Diffusion:
def __init__(self, config, tokenizer, device, disable_init=False):
    """Bind helper functions as methods on this instance, then optionally initialize.

    The model's behavior lives in free functions in ``model_setup``,
    ``model_utils`` and ``model_eval``. Each listed name is looked up in its
    module and attached to ``self`` as a bound method, in module order.
    Unless ``disable_init`` is true, ``model_setup.init`` is then called with
    the constructor arguments.
    """
    super().__init__()
    method_sources = (
        (model_setup, (
            'init', 'to', 'get_params', 'get_vae', 'get_cond_vae', 'configure_optimizers',
            '_validate_configuration', 'register_signal_handler', 'on_train_start',
            'optimizer_step', 'init_dataloader', 'set_accelerator', 'set_callbacks',
            'on_train_step_end', 'init_optimizer_lr_scheduler', 'after_backward', 'checkpoint',
            'print_hashes', 'shortcut_return', 'reset_validation_metrics', 'unwrap_model',
        )),
        (model_utils, (
            'get_coord_plot', '_score_entropy', 'sample_subs_guidance',
            'restore_model_and_semi_ar_sample', '_reconstruction_loss',
            'restore_model_and_sample', 'get_score', '_staggered_score',
            '_analytic_update', '_denoiser_update', '_transp_transition',
            'eval_retokenize', 'compute_generative_perplexity', '_d3pm_loss',
            '_d3pm_parameterization', '_sedd_parameterization',
            'get_base_shapes_for_mup', 'update_histogram', '_maybe_sub_sample',
            'viz_images_from_dataloader', 'compute_cider',
        )),
        (model_eval, (
            'get_every_n_evals', 'on_validation_epoch_start', 'sample',
            'predict_step', 'validation_step', 'on_validation_epoch_end',
            'on_validation_epoch_cleanup', '_sample_prior', '_ddpm_forward',
            '_ddpm_update', '_ddpm_caching_update', '_sample', '_ar_sampler',
            'decode_batch', 'sample_transfusion', 'sample_continuous_image',
            'decode_sampling', '_ddpm_update_finetune_controlled_tweedie',
            'sample_masking', 'log_flops', "visualize_samples", "_maskgit_update",
            "_first_hitting_update", "update_inline_fid", "compute_inline_fid",
            "update_clean_fid", "compute_clean_fid_eval", "sample_for_fid",
            "compute_clip_score", "mauve_store_references", "zero_shot_eval_step",
            "zero_shot_eval_epoch_end", "get_cfg_weight", "cleanup_fid_output",
            "calculate_chameleon_perplexity", "get_anole_data",
            "update_img_to_txt_mauve_clip", "compute_mauve_entropy",
            "get_top_k", "compute_entropy", "get_mauve_score", "get_valid_seq", "gather_tokens",
            "count_valid_tokens", "compute_val_metrics_standalone", "_maskgit_nucleus_update",
            "get_img_text_saturation_batch", "handle_interleaved_decode", "get_interleaved_image",
            "auto_enhance", "get_clip_score", "get_dfn_score", "get_hpsv2_score", "get_model_likelihood_score",
            "get_laion_aesthetic_score", "get_rewards", "get_chameleon_score", "clear_reward_models",
            "get_text_likelihood_score", "get_text_reward_model_score", "save_image_text_pair",
        )),
    )
    for source_module, method_names in method_sources:
        for method_name in method_names:
            bound_method = types.MethodType(getattr(source_module, method_name), self)
            setattr(self, method_name, bound_method)
    if not disable_init:
        model_setup.init(self, config, tokenizer, device)
@cached_property
def xla_mesh(self):
    """Global SPMD device mesh from torch_xla; computed on first access and then cached."""
    from torch_xla.distributed.spmd import get_global_mesh
    return get_global_mesh()
def on_train_resume(self):
    """Return the model to training mode (presumably after an evaluation pass).

    Steps:
    * Empty the device cache, except on XLA.
    * Call ``self.ema.restore(...)`` when an EMA object exists and
      ``trainer.use_custom_ema`` is off. This presumably swaps the stored
      training weights back in; confirm against the EMA implementation.
    * Put the backbone in train mode.
    """
    on_xla = is_torch_xla_available()
    if not on_xla:
        empty_device_cache()
    uses_library_ema = self.ema is not None and not self.config.trainer.use_custom_ema
    if uses_library_ema:
        self.ema.restore(self.get_params(), raise_error_if_already_restored=False)
    self.backbone.train()
def zero_shot_update_batch(self, batch):
"""Build zero-shot evaluation inputs for the dataset named by ``config.data.train``.

If ``config.data.train`` is None, the batch is returned unchanged.

"nlphuji/flickr30k":
The caption ids are kept in ``gt_input_ids``. Images are tokenized with
``get_image_batch`` and offset by ``text_vocab_size``. ``input_ids``
becomes [zeros over the text positions, image tokens].
``attention_mask`` is False on text and True on image positions;
``modality`` is 0 on text and 1 on image positions.

"facebook/winoground":
Both images are VAE-encoded and offset by ``text_vocab_size``. All four
caption/image pairings are stored as ``input_ids_0_0``, ``input_ids_0_1``,
``input_ids_1_0`` and ``input_ids_1_1`` (caption index first).
``attention_mask`` and ``modality`` are built from caption 0 and image 0
only. Presumably all pairings share the same lengths; confirm.

Finally, ``modality_mask`` is set to a two-class one-hot bool view of
``batch["modality"]``, and the batch is returned.
"""
dataset = self.config.data.train
if dataset is None:
return batch
# Model-config lookup that yields None when the attribute is missing.
def get_attr(attr_name):
return getattr(self.config.model, attr_name, None)
if dataset == "nlphuji/flickr30k":
# Image captioning dataset; the sequence layout is [txt, img].
# Text positions are zero-filled and excluded from the attention mask;
# the original caption ids are kept in gt_input_ids.
batch['gt_input_ids'] = batch['input_ids']
image_input_ids = get_image_batch(self.config, self.get_vae(), batch, self.device)
# Offset image token ids past the text vocabulary (in place on the returned tensor).
image_input_ids += self.text_vocab_size
batch["input_ids"] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.int64), image_input_ids], dim=-1).to(self.device)
batch['attention_mask'] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.bool), torch.ones_like(image_input_ids, dtype=torch.bool)], dim=-1).to(self.device)
batch["modality"] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.int64), torch.ones_like(image_input_ids, dtype=torch.int64)], dim=-1).to(self.device)
elif dataset == "facebook/winoground":
# Read both captions (already tokenized) and both raw images.
caption_0_input_ids = batch['caption_0_input_ids']
caption_1_input_ids = batch['caption_1_input_ids']
image_0 = batch['img_0']
image_1 = batch['img_1']
# Encode each image with the VAE and offset its ids past the text vocabulary.
image_0_input_ids = vae_encode_image(self.config, self.get_vae(), image_0, self.device, get_attr("vae_type")) + self.text_vocab_size
image_1_input_ids = vae_encode_image(self.config, self.get_vae(), image_1, self.device, get_attr("vae_type")) + self.text_vocab_size
# Build all 4 caption/image combinations (key suffix: caption index, image index).
batch['input_ids_0_0'] = torch.cat([caption_0_input_ids, image_0_input_ids], dim=-1).to(self.device)
batch['input_ids_0_1'] = torch.cat([caption_0_input_ids, image_1_input_ids], dim=-1).to(self.device)
batch['input_ids_1_0'] = torch.cat([caption_1_input_ids, image_0_input_ids], dim=-1).to(self.device)
batch['input_ids_1_1'] = torch.cat([caption_1_input_ids, image_1_input_ids], dim=-1).to(self.device)
# Mask/modality are derived from caption 0 + image 0 only.
batch['attention_mask'] = torch.cat([torch.zeros_like(caption_0_input_ids, dtype=torch.bool), torch.ones_like(image_0_input_ids, dtype=torch.bool)], dim=-1).to(self.device)
batch['modality'] = torch.cat([torch.zeros_like(caption_0_input_ids, dtype=torch.int64), torch.ones_like(image_0_input_ids, dtype=torch.int64)], dim=-1).to(self.device)
# NOTE(review): indentation is lost in this view. This line presumably runs for every
# dataset (a leftover `elif dataset == "facebook/winoground":` header was commented out
# here). For any other dataset it requires batch["modality"] to exist already; confirm.
batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool)
return batch
def update_batch(self, batch):
"""Normalize a raw dataloader batch into the tensors the model consumes.

Main steps, each gated by config:
* ``eval.big_seq_len_eval``: return a placeholder batch of length
  ``model.length`` instead of the real data. It has all-zero input_ids and
  a full attention mask; the first half is text modality and the second
  half is image modality.
* Build ``input_ids`` / ``attention_mask`` / ``modality`` in one of two
  ways: from separate ``txt_input_ids`` / ``img_input_ids``, or by
  tokenizing raw images with ``get_image_batch``. Image ids are offset by
  ``text_vocab_size`` in the unified paths.
* Derive ``modality_mask`` (two-class one-hot), ``batch_contains_img``,
  ``txt_sl`` and ``img_sl``.
* Optionally:
  - write a class label at position 0;
  - move tensors to ``self.device``;
  - force a full attention mask;
  - tie the attention mask to ``sample_ids``;
  - randomly flip rows from [txt, img] to [img, txt] for the "ar"
    parameterization;
  - attach interleaved block metadata.

Returns the updated batch. If ``batch`` is None, a warning is printed and
None is returned.
"""
# Eval-only shortcut: replace the real batch with a zero-filled placeholder.
# NOTE(review): this runs before the `batch is None` check below, so a None batch
# crashes here when big_seq_len_eval is set.
if getattr(self.config.eval, 'big_seq_len_eval', False):
# Placeholder of length N = model.length: first N//2 positions text, last N//2 image
# (e.g. 8192 = 4096 text + 4096 image).
N = self.config.model.length
new_batch = dict()
new_batch['input_ids'] = torch.zeros(batch['input_ids'].shape[0], N, device=self.device, dtype=batch['input_ids'].dtype)
new_batch['attention_mask'] = torch.ones(batch['attention_mask'].shape[0], N, device=self.device, dtype=batch['attention_mask'].dtype)
new_batch['modality'] = torch.zeros(batch['modality'].shape[0], N, device=self.device, dtype=batch['modality'].dtype)
new_batch['modality'][:, N//2:] = 1
new_batch['modality_mask'] = F.one_hot(new_batch['modality'], num_classes=2).to(torch.bool)
batch = new_batch
return batch
continuous_mode = self.config.trainer.image_mode == "continuous"
if batch is None:
gprint(f"Warning! Batch is None")
return batch
# Keep only the leading (batch) dimension as the TensorDict batch size.
if isinstance(batch, TensorDict):
batch.batch_size = (batch.batch_size[0],)
if self.image_model or getattr(self.config.data, "force_image_dataset", False):
# Stays None unless the legacy unified path below reads text ids from the batch.
text_input_ids = None
if isinstance(batch, TensorDict) and (self.is_compiled or getattr(self.config.trainer, "force_convert_to_dict", False)):
batch = dict(batch.items())
# Pre-tokenized path: image ids (and optionally text ids) arrive as separate keys.
if "txt_input_ids" in batch or "img_input_ids" in batch:
index_keys = ["img_input_ids", "txt_input_ids", "sample_ids"]
for key in index_keys:
if key in batch:
if isinstance(batch[key], list):
batch[key] = torch.stack(batch[key], dim=0)
batch[key] = batch[key].to(torch.int64)
index_keys = ["img_label"]
for key in index_keys:
if key in batch:
batch[key] = batch[key].squeeze(-1)
img_input_ids = batch.pop("img_input_ids")
batch["input_ids"] = img_input_ids
batch["attention_mask"] = torch.ones_like(img_input_ids).to(torch.bool)
# Unified layout [txt, img]: image ids are offset past the text vocabulary.
if "txt_input_ids" in batch:
batch["input_ids"] = torch.cat([batch["txt_input_ids"], batch["input_ids"] + self.text_vocab_size], dim=-1)
batch["attention_mask"] = torch.cat([batch["txt_attention_mask"], batch["attention_mask"]], dim=-1)
batch["input_ids"] = batch["input_ids"].to(torch.int64)
if "modality" not in batch:
if getattr(self.config.trainer, "ignore_text_in_unified", False):
modality = torch.ones_like(batch["input_ids"], dtype=torch.int64)
else:
assert self.config.model.txt_length > 0 and self.config.model.img_length > 0
modality = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
# The trailing img_input_ids.shape[-1] positions are image tokens.
modality[:, -img_input_ids.shape[-1]:] = 1
batch["modality"] = modality
# Multimodal / continuous path: raw images in batch["img"] are tokenized and written
# into the image positions of an existing input_ids/modality layout.
elif (self.config.trainer.multimodal_batches or continuous_mode) and \
not getattr(self.config.trainer, "use_legacy_update_batch_fn", False):
if "img" in batch:
is_image_batch = (batch["modality"] == 1).all(dim=-1)
image_input_ids = get_image_batch(self.config, self.get_vae(), batch, self.device)
# Each row has either no image positions or at least as many as one image's tokens.
assert ((batch["modality"].sum(dim=-1) == 0) | (batch["modality"].sum(dim=-1) >= image_input_ids.shape[1])).all()
if getattr(self.config.trainer, "add_label", False):
assert (batch["modality"] == 1).all()
# Position 0 is reserved for the label; image tokens fill the rest.
batch["input_ids"][:, 1:] = torch.where(is_image_batch[:, None], image_input_ids, batch["input_ids"][:, 1:])
elif image_input_ids.ndim == 3:
# Continuous embeddings: stored separately, NaN at non-image positions.
batch["img_emb"] = torch.where((batch["modality"] == 1)[:, :, None], image_input_ids, torch.nan)
elif (batch["input_ids"][batch["modality"] == 1] == -1).all():
# Image positions hold -1 placeholders: scatter image tokens into them in order.
batch["input_ids"].masked_scatter_(batch["modality"] == 1, image_input_ids)
else:
# Otherwise replace whole rows that are entirely image.
batch["input_ids"] = torch.where(is_image_batch[:, None], image_input_ids, batch["input_ids"])
if getattr(self.config.trainer, "force_shift_raw_image_batches", False):
assert not getattr(self.config.trainer, "force_shift_image_batches", False)
batch["input_ids"] = torch.where(batch["modality"] == 1, batch["input_ids"] + self.text_vocab_size, batch["input_ids"])
else:
# Pre-tokenized multimodal batch: optionally prepend the class label token
# (label ids live in the last model.add_labels ids of the vocabulary).
if getattr(self.config.trainer, "add_label", False):
shift_index = self.vocab_size - self.config.model.add_labels
batch["input_ids"] = torch.cat([batch["label"] + shift_index, batch["input_ids"]], dim=-1)
batch["attention_mask"] = torch.cat([torch.zeros_like(batch["label"], dtype=torch.bool), batch["attention_mask"]], dim=-1)
batch["modality"] = torch.cat([torch.ones_like(batch["label"], dtype=torch.int64), batch["modality"]], dim=-1)
assert (batch["modality"] == 1).all()
batch["input_ids"] = batch["input_ids"].to(torch.int64)
if "sample_ids" in batch:
batch["sample_ids"] = batch["sample_ids"].to(torch.int64)
if getattr(self.config.trainer, "force_shift_image_batches", False):
batch["input_ids"] = torch.where(batch["modality"] == 1, batch["input_ids"] + self.text_vocab_size, batch["input_ids"])
else:
# Legacy path: tokenize images and concatenate after the text (if any).
if continuous_mode:
assert False
else:
if "input_ids" in batch and not self.config.trainer.ignore_text_in_unified:
assert self.config.model.unified_model
assert "attention_mask" in batch
text_input_ids = batch["input_ids"]
image_ids = get_image_batch(self.config, self.get_vae(), batch, self.device)
image_attention_mask = torch.ones_like(image_ids).to(torch.bool)
if "cond_img" in batch:
cond_image_ids = get_image_batch(self.config, self.get_cond_vae(), batch, self.device, use_cond=True)
batch["cond_input_ids"] = cond_image_ids
if text_input_ids is not None:
assert batch["input_ids"].shape[1] == self.config.model.txt_length
assert image_ids.shape[1] == self.config.model.img_length
image_ids = image_ids + self.text_vocab_size
batch["input_ids"] = torch.cat([batch["input_ids"].to(self.device), image_ids], dim=-1)
batch["attention_mask"] = torch.cat([batch["attention_mask"].to(self.device), image_attention_mask], dim=-1).to(torch.bool)
assert batch["input_ids"].shape[1] == batch["attention_mask"].shape[1] == self.config.model.length
batch["modality"] = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
batch["modality"][:, -image_ids.shape[-1]:] = 1
else:
assert self.unified_model is False
batch["input_ids"] = image_ids
batch["attention_mask"] = image_attention_mask
batch["modality"] = torch.ones_like(batch["input_ids"], dtype=torch.int64)
# Inpainting-style inputs: positions not in x0_unmask are replaced by the mask token.
if "txt_x0_unmask" in batch and "img_x0_unmask" in batch:
assert not continuous_mode
batch["gt_img_input_ids"] = image_ids
batch["x0_unmask"] = torch.cat([batch["txt_x0_unmask"], batch["img_x0_unmask"]], dim=-1)
batch["input_ids"][~batch["x0_unmask"]] = self.mask_index
if (batch["input_ids"].shape[1] != self.config.model.length) and not self.config.trainer.ar_inpainting:
gprint(f"Warning! Input ids shape: {batch['input_ids'].shape}, model length: {self.config.model.length}")
# NOTE(review): this truncation is dead code; the assert False below always fires.
batch["input_ids"] = batch["input_ids"][:, : self.config.model.length]
assert False, f"input ids are not the correct length input ids shape: {batch['input_ids'].shape}, model length: {self.config.model.length}"
if getattr(self.config.model, "img_cond", False):
assert "cond_input_ids" in batch
assert not continuous_mode
if "modality" in batch:
batch["modality"] = batch["modality"].to(torch.int64)
# A per-row modality of shape (B, 1) is broadcast to the full sequence length.
if self.config.trainer.multimodal_batches and batch["modality"].ndim == 2 and batch["modality"].shape[-1] == 1:
batch["modality"] = batch["modality"].repeat(1, self.config.model.length)
else:
# No modality provided: derive it from the static image slice or txt_only.
if self.image_model and not self.config.trainer.multimodal_batches:
assert self.config.model.txt_length > 0 and self.config.model.img_length > 0
modality = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
modality[:, self.static_img_sl] = 1
batch["modality"] = modality
elif self.config.data.txt_only:
batch["modality"] = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
if "modality" in batch:
# Modality -1 is treated as text (0).
batch["modality"][batch["modality"] == -1] = 0
# NOTE(review): as written, this requires both modalities somewhere in the batch.
# All-text or all-image batches would fail; confirm whether >= 0 / <= 1 was intended.
assert batch["modality"].min() == 0 and batch["modality"].max() == 1
batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool)
batch["batch_contains_img"] = (batch["modality"] == 1).any(dim=-1)
batch['txt_sl'] = self.txt_sl(batch)
batch['img_sl'] = self.img_sl(batch)
if getattr(self.config.trainer, "force_remove_img_tokens", False):
assert not continuous_mode
# NOTE(review): if txt_sl is a boolean mask shaped like input_ids, this indexing
# flattens to 1-D; confirm against self.txt_sl.
batch["input_ids"] = batch["input_ids"][batch['txt_sl']]
batch["attention_mask"] = batch["attention_mask"][batch['txt_sl']]
# Class-label conditioning: the label id (offset into the last model.add_labels
# vocabulary ids) overwrites position 0, which is excluded from attention.
if getattr(self.config.trainer, "add_label", False):
assert getattr(self.config.model, "add_labels", False)
assert "label" in batch
batch["label"] = batch["label"].to(torch.int64)
assert 0 <= batch["label"].min() and batch["label"].max() < self.config.model.add_labels
shift_index = self.vocab_size - self.config.model.add_labels
assert batch["input_ids"].shape[-1] == self.config.model.length
if batch["label"].ndim == 1:
batch["input_ids"][:, [0]] = (batch["label"] + shift_index).unsqueeze(-1)
else:
batch["input_ids"][:, [0]] = batch["label"] + shift_index
batch["attention_mask"][:, 0] = False
# Move every tensor to the model device.
if isinstance(batch, dict):
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(self.device)
elif isinstance(batch, TensorDict):
assert self.config.backbone != "gemma"
batch = batch.to(self.device)
if getattr(self.config.trainer, "force_full_attention_mask", False):
batch["attention_mask"] = torch.ones_like(batch["attention_mask"], dtype=torch.bool)
batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
# Keep sample_ids and attention_mask consistent: unattended positions get sample id -1,
# and sample id -1 positions are unattended.
if self.config.data.require_sample_ids:
assert "sample_ids" in batch
batch["sample_ids"][~(batch["attention_mask"].bool())] = -1
batch["attention_mask"][batch["sample_ids"] == -1] = False
# Flip [txt, img] -> [img, txt] for a random subset of rows (AR parameterization only).
# TODO: Flip per sample rather than per batch row; acceptable for now since we train with ~8 rows per batch.
if (self.training or getattr(self.config.trainer, "force_flip_ar_val", False)) and self.config.parameterization == "ar" and getattr(self.config.trainer, "rand_flip_ar_prob", None) is not None:
assert (batch["modality"][:, :self.config.model.txt_length] == 0).all() and (batch["modality"][:, self.config.model.txt_length:] == 1).all(), "Modality does not match img_before_txt configuration"
batch_flip_mask = torch.rand(batch["modality"].shape[0], device=self.device) < self.config.trainer.rand_flip_ar_prob
img_slice = slice(-self.config.model.img_length, None)
txt_slice = slice(None, self.config.model.txt_length)
for key in ["modality", "attention_mask", "input_ids"]:
batch[key][batch_flip_mask] = torch.cat([batch[key][batch_flip_mask][:, img_slice], batch[key][batch_flip_mask][:, txt_slice]], dim=1)
if "modality_mask" in batch:
batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool)
# The static text/image slices no longer describe flipped rows.
batch['txt_sl'] = None
batch['img_sl'] = None
batch["batch_flip_mask"] = batch_flip_mask
if self.config.trainer.interleaved and "sample_ids" not in batch:
batch["sample_ids"] = torch.zeros_like(batch["modality"], dtype=torch.int64)
# Interleaved batches: record the contiguous same-modality blocks of each row.
if self.config.trainer.interleaved:
batch_indices, start_positions, end_positions = get_contiguous_blocks(batch["modality"])
interleaved_metadata = TensorDict({
"batch_indices": batch_indices,
"start_positions": start_positions,
"end_positions": end_positions
}, batch_size=[])
# Image block token counts that are accepted without a warning:
# square grids 8^2, 16^2, 32^2, 48^2, 64^2 (presumably the supported resolutions).
allowed_image_sizes = (64, 256, 1024, 2304, 4096)
block_sizes = (end_positions - start_positions).to(torch.int32)
is_txt_block = batch["modality"][batch_indices, start_positions] == 0
is_valid_img_size = torch.isin(block_sizes, torch.tensor(allowed_image_sizes, dtype=torch.int32, device=self.device))
if not ((is_txt_block | is_valid_img_size).all()):
gprint(f"WARNING: Found non-text block of size {block_sizes[~(is_txt_block | is_valid_img_size)]} in interleaved batch")
# Presumably reset so the metadata TensorDict (different leading dims) can be stored; confirm.
if isinstance(batch, TensorDict):
batch.batch_size = []
batch["interleaved_metadata"] = interleaved_metadata
return batch
def get_cond_dict(self, batch):
    """Gather the optional conditioning entries of ``batch`` into a new dict.

    Each key is produced only when its condition holds:

    * ``x_cond``: ``batch["cond_input_ids"]``, if present.
    * ``label``: ``batch["img_label"]``, if present.
    * ``attention_mask``: if ``model.use_attention_mask``.
    * ``continuous_mode`` (True): if ``trainer.image_mode == "continuous"``.
    * ``modality``: if ``trainer.multimodal_batches``, in continuous image
      mode, or for the ``"ar"`` parameterization when ``"modality"`` is in
      the batch.
    """
    trainer_cfg = self.config.trainer
    cond = {}
    if "cond_input_ids" in batch:
        cond["x_cond"] = batch["cond_input_ids"]
    if "img_label" in batch:
        cond["label"] = batch["img_label"]
    if self.config.model.use_attention_mask:
        cond["attention_mask"] = batch["attention_mask"]
    if trainer_cfg.multimodal_batches:
        cond["modality"] = batch["modality"]
    is_continuous = trainer_cfg.image_mode == "continuous"
    if is_continuous:
        cond["continuous_mode"] = True
    ar_with_modality = self.parameterization == "ar" and "modality" in batch
    if is_continuous or ar_with_modality:
        cond["modality"] = batch["modality"]
    return cond
def training_step(self, batch, batch_idx):
    """Prepare ``batch`` with ``update_batch`` and return ``compute_loss`` for the "train" prefix."""
    prepared = self.update_batch(batch)
    return self.compute_loss(prepared, prefix="train", batch_idx=batch_idx)
def q_xt(self, x, move_chance, allow_move_mask=None, return_ignore_batch_mask_for_metrics=False, mask_image_square=False, mask_text_region=False, batch=None):
"""Computes the noisy sample xt.
Args:
x: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
move_chance: float torch.Tensor with shape (batch_size, 1).
"""
if self.config.backbone == "maskdit" and getattr(self.config.trainer, "force_single_timestep_per_batch", False):
num_to_mask = int(x.shape[1] * move_chance[0].item())
batch_size, seq_len = x.shape
random_indices = torch.rand(batch_size, seq_len, device=x.device).argsort(dim=1)[:, :num_to_mask]
xt = x.scatter(1, random_indices, self.mask_index)
return xt
move_indices = torch.rand(*x.shape, device=x.device) < move_chance
if mask_image_square:
latent_dim = int(math.sqrt(self.config.model.img_length))
img_move_indices = move_indices[:, self.static_img_sl].clone().reshape(move_indices.shape[0], latent_dim, latent_dim)
max_d = int(math.sqrt(self.config.model.img_length))
for b in range(move_indices.shape[0]):
if move_chance[b] == 1:
continue
h, w = img_move_indices[b].shape
d = random.randint(max_d // 2, max_d - 2)
i = random.randint(0, h - d)
j = random.randint(0, w - d)
mask = torch.zeros_like(img_move_indices[b], dtype=torch.bool)
mask[i:i+d, j:j+d] = True
move_indices[b, self.static_img_sl] = mask.reshape(-1)
if mask_text_region:
for b in range(x.shape[0]):
if move_chance[b] == 1:
continue
should_mask = torch.zeros_like(move_indices[b, self.static_txt_sl], dtype=torch.bool)
max_valid = (x[b] == self.tokenizer.eos_token_id).nonzero()[0, 0] if self.tokenizer.eos_token_id in x[b] else x.shape[1]
d = random.randint(max_valid//3, max_valid-1)
start = random.randint(0, max_valid - d)
should_mask[start:start+d] = True
move_indices[b, self.static_txt_sl] = should_mask
ignore_batch_mask_for_metrics = None
should_mask_txt, should_mask_img = None, None
if (mask_prob := getattr(self.config.trainer, "mask_entire_modality", None)) is not None \
and (mask_image_square is False and mask_text_region is False) and self.backbone.training:
assert batch is not None
batch_size, seq_len = x.shape
if getattr(self.config.trainer, "mask_txt_only", False):
should_mask_txt = torch.rand(batch_size, 1, device=x.device) < mask_prob
should_mask_img = torch.zeros_like(should_mask_txt, device=x.device)
else:
should_mask_txt = torch.rand(batch_size, 1, device=x.device) < mask_prob/2
should_mask_img = torch.rand(batch_size, 1, device=x.device) < mask_prob/2
if self.config.trainer.multimodal_batches:
if self.config.trainer.interleaved:
batch_indices, start_positions, end_positions = get_contiguous_blocks_per_sample(batch["modality"], batch["sample_ids"])
block_size = end_positions - start_positions
size_mask = block_size > 4
batch_indices, start_positions, end_positions = batch_indices[size_mask], start_positions[size_mask], end_positions[size_mask]
block_counts = torch.zeros_like(batch_indices)
max_num_sample_ids = torch.zeros_like(batch_indices)
for i in range(len(batch_indices)):
curr_sample_id = batch["sample_ids"][batch_indices[i], start_positions[i]]
# Find blocks before this one with same batch index and sample_id
prev_blocks_mask = (batch_indices[:i] == batch_indices[i]) & \
(batch["sample_ids"][batch_indices[:i], start_positions[:i]] == curr_sample_id)
total_in_sample = ((batch_indices == batch_indices[i]) & (batch["sample_ids"][batch_indices, start_positions] == curr_sample_id)).sum()
block_counts[i] = prev_blocks_mask.sum()
max_num_sample_ids[i] = total_in_sample
block_prob = (block_counts + 1) / max_num_sample_ids
positions = torch.arange(move_indices.shape[-1], device=move_indices.device).unsqueeze(0) # Shape: [1, N]
mask = (positions >= start_positions.unsqueeze(1)) & (positions < end_positions.unsqueeze(1)) # Shape: [M, N]
mask = mask & (torch.rand(batch_indices.shape[0], 1, device=x.device) < (mask_prob * block_prob * 2)[..., None])
expanded_batch_indices = batch_indices.unsqueeze(1).expand(-1, move_indices.shape[1]) # Shape: [M, N]
# True if we should manually mask the part of the sequence
accum = torch.zeros_like(move_indices, dtype=torch.int32) # Shape: [B, N]
accum.scatter_add_(0, expanded_batch_indices, mask.int()) # Accumulate counts
accum = accum.to(torch.bool)
move_indices = move_indices | accum
# We ignore the entire sequence if any of the blocks are fully masked
ignore_batch_mask_for_metrics = torch.zeros((move_indices.shape[0],), device=x.device, dtype=torch.bool)
ignore_batch_mask_for_metrics.scatter_add_(0, batch_indices, mask.any(dim=-1))
else:
# TODO: Be smarter about masking for interleaved
# To make sure that we have even masking prob, we prefer to mask less but equally
both_mask = should_mask_txt & should_mask_img
should_mask_txt = torch.where(both_mask, False, should_mask_txt)
should_mask_img = torch.where(both_mask, False, should_mask_img)
move_indices = torch.where(should_mask_txt, batch["modality_mask"][..., 0], move_indices)
move_indices = torch.where(should_mask_img, batch["modality_mask"][..., 1], move_indices)
ignore_batch_mask_for_metrics = should_mask_img | should_mask_txt
else:
both_mask = should_mask_txt & should_mask_img
should_mask_txt[both_mask] = False
should_mask_img[both_mask] = False
should_mask_img[batch["txt_sl"].all(dim=-1)] = False
move_indices[:, self.static_txt_sl] = torch.where(should_mask_txt, True, move_indices[:, self.static_txt_sl])
move_indices[:, self.static_img_sl] = torch.where(should_mask_img, True, move_indices[:, self.static_img_sl])
ignore_batch_mask_for_metrics = should_mask_img | should_mask_txt
joint_ar_nar_mask = None
if self.config.trainer.joint_ar_nar_prob is not None and self.training:
batch_size = x.shape[0]
current_prob = linear_warmup(
current_step=self.global_step,
warmup_steps=self.config.trainer.joint_ar_nar_prob_warmup_steps,
final_value=self.config.trainer.joint_ar_nar_prob,
initial_value=1.0
)
joint_ar_nar_mask = torch.rand(batch_size, device=x.device) < current_prob
move_indices = torch.where(joint_ar_nar_mask[:, None], False, move_indices)
if self.config.trainer.add_label:
move_indices[:, 0] = False
if self.config.trainer.first_token_dropout is not None and self.training:
_initial_mask = torch.rand(x.shape[0], device=x.device) < self.config.trainer.first_token_dropout
move_indices[:, 0] = torch.where(_initial_mask, True, move_indices[:, 0])
if ignore_batch_mask_for_metrics is None:
ignore_batch_mask_for_metrics = _initial_mask
else:
ignore_batch_mask_for_metrics = ignore_batch_mask_for_metrics | _initial_mask
if allow_move_mask is not None:
move_indices = move_indices & allow_move_mask
if getattr(self.config.trainer, "discrete_diffusion_mode", "absorbing") == "uniform":
if getattr(self.config.model, "force_argmax_valid_indices", False):
assert self.mask_index == self.text_vocab_size - 1
text_random_tokens = torch.randint(0, self.text_vocab_size - 1, size=x.shape, device=x.device)
img_random_tokens = torch.randint(self.text_vocab_size, self.vocab_size, size=x.shape, device=x.device)
random_tokens = torch.where(batch["modality_mask"][..., 0], text_random_tokens, img_random_tokens)
assert not torch.any(random_tokens == self.mask_index)
else:
random_tokens = torch.randint(0, vocab_size, size=x.shape, device=x.device)
random_tokens = torch.where(random_tokens == self.mask_index, random_tokens + 1, random_tokens) # avoid mask index
xt = torch.where(move_indices, random_tokens, x)
else:
xt = torch.where(move_indices, self.mask_index, x)
if self.parameterization == "ar":
xt = x.clone()
if return_ignore_batch_mask_for_metrics:
return xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, should_mask_txt, should_mask_img, move_indices
else:
return xt
def _sample_t(self, n, device):
if self.config.backbone == "maskdit" and getattr(self.config.trainer, "force_single_timestep_per_batch", False):
_eps_t = torch.rand(1, device=device).repeat(n)
else:
_eps_t = torch.rand(n, device=device)
if self.config.trainer.joint_ar_nar_timestep_warmup_steps is not None:
max_t = linear_warmup(
current_step=self.global_step,
warmup_steps=self.config.trainer.joint_ar_nar_timestep_warmup_steps,
final_value=1,
initial_value=0,
start_step=0
)
_eps_t = _eps_t * max_t
if max_t == 1:
offset = torch.arange(n, device=device) / n
_eps_t = (_eps_t / n + offset) % 1
elif self.antithetic_sampling:
offset = torch.arange(n, device=device) / n
_eps_t = (_eps_t / n + offset) % 1
if getattr(self.config.trainer, "force_timestep", None) is not None:
_eps_t[:] = self.config.trainer.force_timestep
elif getattr(self.config.eval, "ar_inpainting_force_val", None) is not None:
_eps_t[:] = self.config.eval.ar_inpainting_force_val
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
if self.importance_sampling:
return self.noise.importance_sampling_transformation(t)
return t.to(torch.float32)
def _subs_parameterization(self, logits, xt, batch=None, modality=None, **kwargs):
# log prob at the mask index = - infinity
if not self.allow_slicing:
logits = logits.clone()
logits[..., self.mask_index] += self.neg_infinity
if getattr(self.config.model, "force_argmax_valid_indices", False):
if self.config.trainer.multimodal_batches:
_txt_sl = batch["txt_sl"] if modality is None else modality == 0
_img_sl = batch["img_sl"] if modality is None else modality == 1
logits[..., self.text_vocab_size:] = torch.where(_txt_sl[..., None], self.neg_infinity, logits[..., self.text_vocab_size:])
logits[..., :self.text_vocab_size] = torch.where(_img_sl[..., None], self.neg_infinity, logits[..., :self.text_vocab_size])
else:
logits[..., self.static_txt_sl, self.text_vocab_size:] = self.neg_infinity
logits[..., self.static_img_sl, :self.text_vocab_size] = self.neg_infinity
# Normalize the logits such that x.exp() is
# a probability distribution over vocab_size.
logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
if self.parameterization != "ar" and xt is not None:
# Apply updates directly in the logits matrix.
# For the logits of the unmasked tokens, set all values
# to -infinity except for the indices corresponding to
# the unmasked tokens.
unmasked_indices = xt != self.mask_index
if not self.allow_slicing:
logits = torch.where(unmasked_indices.unsqueeze(-1), torch.full_like(logits, self.neg_infinity), logits)
logits = torch.where(
unmasked_indices.unsqueeze(-1) & (torch.arange(logits.size(-1)).to(logits.device) == xt.unsqueeze(-1)),
torch.zeros_like(logits),
logits
)
else:
logits[unmasked_indices] = self.neg_infinity
logits[unmasked_indices, xt[unmasked_indices]] = 0
return logits
def _process_sigma(self, sigma):
"""Validate and normalize the noise level ``sigma`` before it reaches the backbone.

A None sigma is returned as-is. This is only allowed for the "ar"
parameterization, when ``trainer.ar_llm_loss`` is set, or when
``trainer.allow_null_sigma`` is set; otherwise the assert fails.

Outside continuous image mode, a sigma with more than one dimension is
squeezed along its last axis and asserted to be 1-D (see the NOTE below
on that assert's scope).

If ``self.time_conditioning`` is False and ``model.force_time_conditioning``
is set, sigma is replaced by zeros of the same shape.
"""
if sigma is None:
assert (self.parameterization == "ar" or self.config.trainer.ar_llm_loss) or self.config.trainer.allow_null_sigma
return sigma
# NOTE(review): indentation is lost in this view. Confirm whether the ndim == 1 assert
# is scoped to the squeeze branch or applies to every non-None sigma; the difference
# matters in continuous mode, where multi-dim sigma is not squeezed.
if sigma.ndim > 1 and not self.config.trainer.image_mode == "continuous":
sigma = sigma.squeeze(-1)
assert sigma.ndim == 1, sigma.shape
# NOTE(review): enabling model.force_time_conditioning here *zeroes* sigma (drops the
# timestep signal) when time_conditioning is off; the flag name suggests otherwise. Confirm intent.
if not self.time_conditioning and getattr(self.config.model, "force_time_conditioning", False):
sigma = torch.zeros_like(sigma)
return sigma
def forward(
    self,
    x,
    sigma,
    batch=None,
    forward_attention_mask=None,
    return_additional_loss=False,
    x_img_emb=None,
    disable_ar_shift=False,
    continuous_mode=False,
    joint_ar_nar_mask=None,
    return_logits=False,
    block_mask=None,
    update_cache_slice=None,
    **kwargs,
):
    """Returns log score.

    Runs the backbone on ``x`` (optionally under autocast), then post-processes
    the raw logits according to ``self.parameterization``.

    Args:
        x: Input token ids for the backbone.
        sigma: Noise-level conditioning, normalized by ``_process_sigma``.
        batch: Optional batch dict. For ``ar`` it is the preferred source of
            ``modality``; for ``subs`` it is forwarded to ``_subs_parameterization``.
        forward_attention_mask: Only used with the ELM backbone plus
            ``trainer.print_llm_ppl``, to build ``-100``-padded labels.
        return_additional_loss: With ``trainer.ar_llm_loss``, return
            ``(log_probs, raw_logits)`` instead of only ``log_probs``.
        x_img_emb, block_mask, update_cache_slice: Forwarded to the non-ELM backbone.
        disable_ar_shift: Skip the one-token AR shift (used by samplers at inference).
        continuous_mode: Backbone returns ``(logits, logits_img)``; requires ``ar``.
        joint_ar_nar_mask: Not used by this method; accepted for call-site compatibility.
        return_logits: For ``subs``, return the raw logits unchanged.
        **kwargs: Extra backbone kwargs (e.g. ``modality``).

    Returns:
        - ``planner``: raw logits.
        - ``ar``: log-softmax logits, or ``(logits, logits_img)`` in continuous mode.
        - ``ar_llm_loss``, ``subs``, ``sedd``, ``d3pm``: the respective
          parameterization output.
    """
    sigma = self._process_sigma(sigma)
    if self.config.trainer.image_mode == "continuous":
        assert "modality" in kwargs
    # NOTE(review): with `disable_forward_autocast_during_eval` set, this disables autocast
    # while the backbone is *training* and keeps it during eval; a None flag disables it
    # always. The flag name suggests the opposite; confirm intent.
    should_autocast = (((self.config.trainer.disable_forward_autocast_during_eval and self.backbone.training) is False) and (self.dtype != torch.float32))
    with ExitStack() as stack:
        if should_autocast:
            stack.enter_context(torch.autocast(device_type=self.device.type, dtype=self.dtype))
        orig_modality = None
        if self.config.backbone == "elm":
            if getattr(self.config.trainer, "print_llm_ppl", False):
                _labels = x.clone()
                _labels[~forward_attention_mask] = -100
                kwargs['labels'] = _labels
            if "modality" in kwargs:
                # Keep modality around in eval mode so it can be restored (and shifted) below.
                if self.config.mode == "eval":
                    orig_modality = kwargs.pop("modality")
                else:
                    kwargs.pop("modality")
            # The ELM backbone does not accept these keys.
            for unsupported_key in ("modality_mask", "x0", "start_pos", "sample_ids"):
                kwargs.pop(unsupported_key, None)
            output = self.backbone(input_ids=x, **kwargs)
            if self.config.mode == "eval":
                kwargs["modality"] = orig_modality
            logits = output if isinstance(output, Tensor) else output.logits
            if getattr(self.config.trainer, "print_llm_ppl", False):
                rprint(f"AR PPL: {torch.exp(output.loss)}")
        else:
            if self.config.trainer.compile == 'max-autotune' and not is_xla_available:
                torch.compiler.cudagraph_mark_step_begin()
            logits = self.backbone(x, sigma, continuous_mode=continuous_mode, x_img_emb=x_img_emb, block_mask=block_mask, update_cache_slice=update_cache_slice, **kwargs)
            if self.config.trainer.force_bf16_eval:
                logits = logits.to(torch.bfloat16)
            if continuous_mode:
                assert self.parameterization == "ar"
                logits, logits_img = logits

    if self.config.trainer.ar_shift and not disable_ar_shift:
        # `trainer.ar_shift` is the training-time setting; `disable_ar_shift` lets the
        # sampler opt out at inference. Position i predicts token i + 1.
        logits = logits[:, :-1]
        xt = x[:, 1:]
        if orig_modality is not None and self.config.mode == 'eval':
            orig_modality = orig_modality[:, 1:]
    else:
        xt = x
    if self.config.trainer.low_precision_loss:
        logits = logits.to(self.dtype)
        if continuous_mode:
            logits_img = logits_img.to(self.dtype)

    if self.parameterization == "planner":
        return logits
    elif self.config.trainer.ar_llm_loss:
        assert not self.parameterization == "ar"
        model_output = self._subs_parameterization(logits, xt=xt, modality=orig_modality), logits
        if is_xla_available:
            shard_output(model_output[0], self.xla_mesh)
            shard_output(model_output[1], self.xla_mesh)
        return model_output if return_additional_loss else model_output[0]
    elif self.parameterization == "ar":
        if not getattr(self.config.trainer, "use_orig_unidisc_dit", False):
            # Never predict the mask token.
            logits = torch.where(
                torch.arange(logits.shape[-1], device=logits.device)[None, None, :] == self.mask_index, self.neg_infinity, logits
            )
        _modality = kwargs.get("modality") if batch is None else batch.get("modality")
        # During eval, we let the sampler handle this part: shapes only line up (modality one
        # longer than the shifted logits) on the training path. Use the same `_modality` we
        # checked; kwargs["modality"] may have been popped (ELM) or be absent when `batch` is given.
        if (
            getattr(self.config.model, "force_argmax_valid_indices", False)
            and _modality is not None
            and _modality.shape[1] == (logits.shape[1] + 1)
        ):
            if not self.allow_slicing:
                logits = logits.clone()
            # Modality of the token each position predicts (after the AR shift).
            next_modality = _modality[..., 1:, None]
            logits[..., self.text_vocab_size:] = torch.where(
                next_modality == 0, torch.finfo(logits.dtype).min, logits[..., self.text_vocab_size:]
            )
            logits[..., :self.text_vocab_size] = torch.where(
                next_modality == 1, torch.finfo(logits.dtype).min, logits[..., :self.text_vocab_size]
            )
        logits = logits.log_softmax(-1)
        if continuous_mode:
            return (logits, logits_img)
    elif self.parameterization == "subs":
        if return_logits:
            return logits
        model_output = self._subs_parameterization(logits, xt=xt, batch=batch, **kwargs)
        if is_xla_available:
            shard_output(model_output, self.xla_mesh)
        return model_output
    elif self.parameterization == "sedd":
        return self._sedd_parameterization(logits=logits, xt=x, sigma=sigma)
    elif self.parameterization == "d3pm":
        return self._d3pm_parameterization(logits=logits)
    return logits
def compute_loss(self, batch, prefix, batch_idx=-1):
    """Compute the loss (and update eval metrics) for one batch.

    Builds the model input for the configured parameterization:
    - NAR: noised via ``q_xt``.
    - AR inpainting: a masked copy followed by the clean sequence.
    - Plain AR: the clean tokens.
    - Continuous image mode: noised image embeddings.

    Then runs ``self.forward`` and reduces the per-token objective. The reduction
    may be split into text/image terms, or combined with an AR cross-entropy term
    (``trainer.ar_llm_loss`` / joint AR-NAR batches).

    Args:
        batch: Batch dict; must contain ``input_ids``. May contain ``attention_mask``,
            ``modality``, ``modality_mask``, ``img_emb`` and ``sample_ids``.
        prefix: ``"train"``, ``"val"`` or ``"test"``.
        batch_idx: Index of the batch within the epoch (``-1`` when unknown).

    Returns:
        A ``Loss`` for ``prefix == "train"``. Some paths return a tensor early:
        ``sedd``, ``planner``, change-of-variables / importance sampling, and
        ``use_orig_unidisc_dit``. For ``"val"`` / ``"test"`` the metrics are
        updated and ``None`` is returned.

    Raises:
        NotImplementedError: if ``trainer.disable_torchmetrics`` is set.
        ValueError: for an unknown ``prefix`` or noise-scheduler prediction type.
    """
    if not is_xla_available and ((self.current_run_fwd_bwd_pass == 0 and self.config.mode == 'train') or batch_idx == 0):
        self.visualize_samples(batch, batch_idx, split=prefix)

    if getattr(self.config.trainer, 'overfit_on_first_batch', False):
        if batch_idx <= 0:
            # Store the first batch so every later step reuses it.
            self.overfit_batch = batch.copy()
        else:
            batch = self.overfit_batch

    kwargs = self.get_cond_dict(batch)
    modality_mask = batch.get("modality_mask", None)
    (input_tokens, output_tokens, attention_mask) = self._maybe_sub_sample(batch["input_ids"], batch.get("attention_mask", None))
    continuous_mode = self.config.trainer.image_mode == "continuous"
    joint_ar_nar_mask, modality = None, None
    if continuous_mode:
        assert 'modality' in batch
        x0, img_emb, attention_mask, modality = (
            batch["input_ids"],
            batch["img_emb"],
            batch["attention_mask"],
            batch["modality"],
        )  # img_emb has [0.] * txt_len + img_emb
        xt = x0
        B, N_tot, C = img_emb.shape
        noise_scheduler = self.get_vae().scheduler
        noise = torch.randn_like(img_emb)
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (B,), device=img_emb.device).long()
        img_timesteps = timesteps.unsqueeze(-1).expand(-1, N_tot).to(self.dtype)
        zero_timesteps = torch.zeros_like(img_timesteps)
        # Per-token conditioning: the sampled timestep on image positions, 0 on text positions.
        unet_conditioning = torch.where(modality == 1, img_timesteps, zero_timesteps)
        x_img_emb = noise_scheduler.add_noise(img_emb, noise, timesteps).to(self.dtype)
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(img_emb, noise, timesteps)  # todo, might break
        elif noise_scheduler.config.prediction_type:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
        target = target.to(self.dtype)
    else:
        unet_conditioning, xt, x0, x_img_emb, modality_mask = None, None, input_tokens, None, batch.get("modality_mask", None)
        if self.parameterization != "ar":
            t = self._sample_t(x0.shape[0], x0.device)
            if self.T > 0:
                # Discretize: t in {1/T, 2/T, ..., 1}.
                t = (t * self.T).to(torch.int)
                t = t / self.T
                t += 1 / self.T
            if self.change_of_variables:
                unet_conditioning = t[:, None]
                f_T = torch.log1p(-torch.exp(-self.noise.sigma_max))
                f_0 = torch.log1p(-torch.exp(-self.noise.sigma_min))
                move_chance = torch.exp(f_0 + t * (f_T - f_0))
                move_chance = move_chance[:, None]
            else:
                # Total noise and its rate.
                sigma, dsigma = self.noise(t)
                unet_conditioning = sigma[:, None]
                move_chance = 1 - torch.exp(-sigma[:, None])
            xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, should_mask_txt, should_mask_img, move_indices = self.q_xt(x0, move_chance, return_ignore_batch_mask_for_metrics=True, batch=batch)
            if (self.config.model.flex_attention_img_masking_prob is not None or self.config.model.flex_attention_txt_masking_prob is not None) and self.backbone.training:
                assert xt.shape[1] == (self.config.model.img_length + self.config.model.txt_length)
                txt_batch_attn_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.model.flex_attention_txt_masking_prob
                img_batch_attn_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.model.flex_attention_img_masking_prob
                # If we mask out a modality, we cannot let it only see itself
                txt_batch_attn_dropout = txt_batch_attn_dropout & ~should_mask_txt.squeeze(-1)
                img_batch_attn_dropout = img_batch_attn_dropout & ~should_mask_img.squeeze(-1)
                kwargs['block_mask'] = get_block_mask(txt_batch_attn_dropout, img_batch_attn_dropout, self.config.model.txt_length, xt.shape[0], xt.shape[1], xt.device)
                # TODO: Somehow report these metrics so we know what's going on
                ignore_batch_mask_for_metrics = ignore_batch_mask_for_metrics | (txt_batch_attn_dropout | img_batch_attn_dropout).unsqueeze(-1)
            if getattr(self.config.trainer, "interleaved_training_flex_attention", False):
                kwargs['block_mask'] = get_interleaved_block_mask(batch["sample_ids"], batch_size=xt.shape[0], seq_len=xt.shape[1], device=xt.device)
                kwargs['sample_ids'] = batch["sample_ids"]
        elif self.config.trainer.ar_inpainting:
            # Duplicate the sequence. The first half is a randomly masked copy that
            # receives no loss (attention_mask 0); the second half is the clean target.
            x0 = torch.cat([x0, x0], dim=1)
            kwargs['modality'] = torch.cat([kwargs['modality'], kwargs['modality']], dim=1)
            attention_mask = torch.cat([torch.zeros_like(attention_mask, dtype=attention_mask.dtype), torch.ones_like(attention_mask, dtype=attention_mask.dtype)], dim=1)
            modality_mask = torch.cat([modality_mask, modality_mask], dim=1)
            min_val, max_val = 0.0, 1.0
            n = x0.shape[0]
            # Stratified per-sample masking rates in [0, 1).
            _eps_t = torch.rand(n, device=self.device)
            offset = torch.arange(n, device=self.device) / n
            _eps_t = (_eps_t / n + offset) % 1
            t = (max_val - min_val) * _eps_t + min_val
            if getattr(self.config.eval, "ar_inpainting_force_val", None) is not None:
                t = torch.full_like(t, getattr(self.config.eval, "ar_inpainting_force_val"), dtype=t.dtype, device=t.device)
            move_indices = torch.rand(*x0.shape, device=x0.device) < t[:, None]
            move_indices[:, x0.shape[1] // 2:] = False
            x0 = torch.where(move_indices, self.mask_index, x0)
            xt = x0
        else:
            xt = x0
            if (self.training or getattr(self.config.trainer, "force_flip_ar_val", False)) and self.config.trainer.rand_ar_modality_dropout is not None:
                assert not is_xla_available
                xt = xt.clone()
                batch_modality_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.trainer.rand_ar_modality_dropout
                first_modality = batch["modality"][:, 0]
                first_modality_mask = batch["modality"] == first_modality[:, None]
                # For a random subset of samples, mask every position sharing the first
                # token's modality and drop those positions from the loss mask.
                xt = torch.where(first_modality_mask & batch_modality_dropout[:, None], self.mask_index, xt)
                attention_mask = torch.where(first_modality_mask & batch_modality_dropout[:, None], False, attention_mask)

    true_logits = None
    model_output = self.forward(
        xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=x_img_emb, joint_ar_nar_mask=joint_ar_nar_mask, **kwargs
    )
    if isinstance(model_output, tuple):
        if continuous_mode:
            # model_output is for text and img_output for image. Both span N_total
            # positions (zeroed out according to the modality mask).
            model_output, img_output = model_output
            B, _, C = img_output.shape
            # Use the modality mask to keep only the matching positions.
            x0 = x0[modality == 0].reshape(B, -1)
            xt = xt[modality == 0].reshape(B, -1)
            attention_mask = torch.ones_like(x0, dtype=torch.bool)  # text is separated out, so nothing needs masking
            img_output = img_output[modality == 1].reshape(B, -1, C)
            target = target[modality == 1].reshape(B, -1, C)
        else:
            model_output, true_logits = model_output
    to_dtype = self.dtype if self.config.trainer.low_precision_loss else torch.float32
    model_output = model_output.to(to_dtype)
    if true_logits is not None:
        true_logits = true_logits.to(self.dtype)
    if continuous_mode:
        img_output = img_output.to(to_dtype)
        target = target.to(to_dtype)
    if self.config.trainer.ar_shift:
        # Align targets and masks with the shifted AR predictions.
        x0 = x0[:, 1:]
        xt = xt[:, 1:]
        attention_mask = attention_mask[:, 1:]
        if modality_mask is not None:
            modality_mask = modality_mask[:, 1:]
        if modality is not None:
            modality = modality[:, 1:]
    if not self.is_compiled:
        utils.print_nans(model_output, "model_output")
    if self.parameterization == "sedd":
        return dsigma[:, None] * self._score_entropy(model_output, sigma[:, None], xt, x0)
    elif self.parameterization == "planner":
        return F.binary_cross_entropy_with_logits(model_output.squeeze(-1), move_indices.float()).mean()

    diffusion_loss = None
    if self.T > 0:
        diffusion_loss = self._d3pm_loss(model_output=model_output, xt=xt, x0=x0, t=t)
        # NOTE: reconstruction_loss is computed but is not currently added to the loss.
        if self.parameterization == "d3pm":
            reconstruction_loss = self._reconstruction_loss(x0)
        elif self.parameterization == "subs" or self.parameterization == "ar":
            reconstruction_loss = 0
    if self.parameterization == "ar":
        if getattr(self.config.trainer, "use_orig_unidisc_dit", False):
            return self.shortcut_return(model_output, x0, attention_mask, prefix)
        log_p_theta = model_output.gather(-1, x0[:, :, None])[:, :, 0]
    else:
        # SUBS parameterization, continuous time
        log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1)
        if self.change_of_variables or self.importance_sampling:
            return log_p_theta * torch.log1p(-torch.exp(-self.noise.sigma_min))
    if self.parameterization == "ar" or getattr(self.config.trainer, "no_ce_weighting", False):
        std_weighting = 1
    else:
        std_weighting = (dsigma / torch.expm1(sigma))[:, None]
    loss = -log_p_theta * std_weighting
    if not (self.parameterization == "ar" or (self.config.trainer.ar_llm_loss and joint_ar_nar_mask is None) or getattr(self.config.trainer, "no_ce_weighting", False)):
        gamma = getattr(self.config.trainer, "softmin_snr", None)
        if gamma is not None:
            # Replace the weight with dsigma / (expm1(sigma) + 1/gamma).
            softmin_weighting = (dsigma / (torch.expm1(sigma) + (1 / gamma)))[:, None]
            loss = -log_p_theta * softmin_weighting
    if diffusion_loss is not None:
        assert self.T > 0
        loss = diffusion_loss
    std_loss = -log_p_theta * std_weighting
    loss_dict = dict(std_loss=std_loss.detach(), extra_losses=dict())
    if self.config.trainer.log_seperate_modal_losses:
        assert not continuous_mode
        loss_dict.update(
            dict(
                std_txt_loss=(std_loss.detach() * modality_mask[..., 0] * attention_mask),
                std_img_loss=(std_loss.detach() * modality_mask[..., 1] * attention_mask),
            )
        )
    if getattr(self.config.trainer, "mask_entire_modality", None) is not None and self.backbone.training and not self.config.parameterization == "ar":
        loss_dict['batch_ignore_loss'] = ignore_batch_mask_for_metrics.squeeze(-1)
    if joint_ar_nar_mask is not None:
        if "batch_ignore_loss" in loss_dict:
            loss_dict["batch_ignore_loss"] = loss_dict["batch_ignore_loss"] | joint_ar_nar_mask
        else:
            loss_dict["batch_ignore_loss"] = joint_ar_nar_mask
    if (self.config.trainer.multimodal_batches or (self.config.trainer.text_loss_weight is not None and self.config.trainer.img_loss_weight is not None)) and not continuous_mode:
        txt_mask = modality_mask[..., 0] & attention_mask
        img_mask = modality_mask[..., 1] & attention_mask
        txt_count = txt_mask.sum()
        img_count = img_mask.sum()
        total_count = txt_count + img_count
        txt_frac = txt_count / total_count
        img_frac = img_count / total_count
        loss_dict["extra_losses"]["trainer/img_frac"] = img_frac
        loss_dict["extra_losses"]["trainer/txt_frac"] = txt_frac
        loss_dict["extra_losses"]["trainer/attention_mask_valid_frac"] = attention_mask.sum() / attention_mask.numel()
    if "batch_ignore_loss" in loss_dict:
        loss_dict["extra_losses"]["trainer/ignore_batch_metrics_frac"] = loss_dict["batch_ignore_loss"].sum() / loss_dict["batch_ignore_loss"].numel()

    if joint_ar_nar_mask is not None:
        pass  # Defer loss mean until after ar_loss is calculated
    elif self.config.trainer.text_loss_weight is not None and self.config.trainer.img_loss_weight is not None:
        assert not continuous_mode
        loss = loss * attention_mask
        txt_loss = (
            loss[txt_mask].sum() / txt_count
        ) * txt_frac * self.config.trainer.text_loss_weight
        img_loss = (
            loss[img_mask].sum() / img_count
        ) * img_frac * self.config.trainer.img_loss_weight
        if getattr(self.config.trainer, "set_max_txt_loss_ratio", None) is not None and not (torch.isnan(img_loss).any() or torch.isnan(txt_loss).any()):
            # Cap the text term at `set_max_txt_loss_ratio` times the (detached) image term.
            max_txt_loss = getattr(self.config.trainer, "set_max_txt_loss_ratio", 1.5) * img_loss.detach()
            scale = torch.minimum(torch.tensor(1.0, device=txt_loss.device), max_txt_loss / (txt_loss.detach() + 1e-8))
            txt_loss = txt_loss * scale
        txt_loss = torch.nan_to_num(txt_loss, nan=0.0)
        img_loss = torch.nan_to_num(img_loss, nan=0.0)
        if getattr(self.config.trainer, "force_remove_img_tokens", False):
            img_loss = torch.tensor(0, device=loss.device, dtype=loss.dtype)
        loss = txt_loss + img_loss
        loss_dict.update(dict(txt_loss=txt_loss.clone().detach(), img_loss=img_loss.clone().detach()))
    elif continuous_mode:
        img_loss = F.mse_loss(img_output, target)
        if attention_mask[:, self.static_txt_sl].numel() == 0:
            # Let grads pass even though this is zeros...
            txt_loss = (loss[:, self.static_txt_sl] * attention_mask[:, self.static_txt_sl]).sum()
        else:
            txt_loss = (loss[:, self.static_txt_sl] * attention_mask[:, self.static_txt_sl]).sum() / attention_mask[:, self.static_txt_sl].sum()
        loss = txt_loss + img_loss * self.config.trainer.image_loss_weight
        loss_dict.update(dict(img_loss=img_loss.clone().detach(), txt_loss=txt_loss.clone().detach()))
    else:
        _attention_mask = torch.ones_like(attention_mask) if getattr(self.config.trainer, "force_full_attention_mask_loss_only", False) else attention_mask
        loss = (loss * _attention_mask).sum() / _attention_mask.sum()
    loss = torch.nan_to_num(loss, nan=0.0)

    ar_loss = None
    if self.config.trainer.ar_llm_loss:
        assert not continuous_mode
        # Only masked (xt == mask_index), attended positions are supervised; -1 marks ignored labels.
        valid_loss = xt == self.mask_index
        _labels = x0.clone()
        _labels = torch.where(valid_loss, _labels, -1)
        _labels = torch.where(~attention_mask.to(torch.bool), -1, _labels)
        _logits = true_logits
        _logits[:, :, self.mask_index] += self.neg_infinity
        if getattr(self.config.model, "force_argmax_valid_indices", False):
            assert not self.config.trainer.multimodal_batches
            _logits[:, self.static_txt_sl, self.text_vocab_size:] = torch.finfo(_logits.dtype).min
            _logits[:, self.static_img_sl, : self.text_vocab_size] = torch.finfo(_logits.dtype).min
        _logits = _logits.contiguous().view(-1, _logits.shape[-1])
        _labels = _labels.contiguous().view(-1)
        if self.config.trainer.ar_print_loss:
            _labels = _labels.to(_logits.device)
            # Debug printout of the per-sample CE. The criterion must exist before it
            # is called, and it must ignore the -1 label sentinel set above.
            print_loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
            ce_loss = print_loss_fct(_logits, _labels).view(x0.shape[0], -1).mean(dim=-1)
            if not hasattr(self, 'histogram'):
                self.histogram = {}
            update_histogram(self.histogram, t, ce_loss)
            rprint(f"ELM loss: move: {move_chance}, t:{t}, {ce_loss}")
        loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none' if joint_ar_nar_mask is not None else 'mean')
        ce_loss = loss_fct(_logits, _labels)
        loss_dict["extra_losses"]["trainer/ce_loss"] = ce_loss
        ar_loss = ce_loss

    if joint_ar_nar_mask is not None:
        masked_true_logits = true_logits.clone()
        masked_true_logits = torch.where(torch.arange(true_logits.shape[-1], device=true_logits.device)[None, None, :] == self.mask_index, self.neg_infinity, masked_true_logits)
        log_softmax = masked_true_logits.log_softmax(-1)
        ar_loss = -log_softmax.gather(-1, x0[:, :, None])[:, :, 0]
        assert ar_loss is not None
        assert ar_loss.ndim == 2
        assert loss.ndim == 2
        # joint_ar_nar_mask is broadcast as [:, None] below, i.e. one flag per sample; the
        # weights are the batch fractions routed to the AR / NAR objectives.
        ar_loss_weight = joint_ar_nar_mask.sum(dim=0) / joint_ar_nar_mask.shape[0]
        nar_loss_weight = 1 - ar_loss_weight
        loss_dict["extra_losses"]["trainer/ar_loss_weight"] = ar_loss_weight.detach().float()
        loss_dict["extra_losses"]["trainer/nar_loss_weight"] = nar_loss_weight.detach().float()
        loss_dict["extra_losses"]["trainer/ce_loss"] = ar_loss.mean().detach().float()
        ar_loss = (ar_loss * ar_loss_weight) * attention_mask
        nar_loss = (loss * nar_loss_weight) * attention_mask
        valid_count = attention_mask.sum()
        if not is_xla_available:
            ar_valid_count = attention_mask[joint_ar_nar_mask].sum()
            nar_valid_count = attention_mask[~joint_ar_nar_mask].sum()
            # NOTE(review): the AR metric uses the weighted/masked ar_loss while the NAR
            # metric uses the raw per-token loss; confirm this asymmetry is intended.
            loss_dict["extra_losses"]["trainer/ar_loss"] = (ar_loss[joint_ar_nar_mask].sum() / ar_valid_count).detach().float()
            loss_dict["extra_losses"]["trainer/nar_loss"] = (loss[~joint_ar_nar_mask].sum() / nar_valid_count).detach().float()
            loss_dict["extra_losses"]["trainer/ar_ppl"] = torch.exp(loss_dict["extra_losses"]["trainer/ar_loss"]).detach().float()
            loss_dict["extra_losses"]["trainer/nar_ppl"] = torch.exp(loss_dict["extra_losses"]["trainer/nar_loss"]).detach().float()
        # Each sample contributes its AR or NAR term; average over all attended tokens.
        # (Previously `+ weighted_z_loss` was added here, an undefined name -> NameError.)
        loss = torch.where(joint_ar_nar_mask[:, None], ar_loss, nar_loss).sum() / valid_count
    elif ar_loss is not None:
        loss = ar_loss

    loss_dict = dict(loss=loss, **loss_dict)
    std_loss = loss_dict.get("std_loss", 0)
    std_nlls = std_loss * attention_mask
    if "batch_ignore_loss" in loss_dict:
        # Exclude ignored samples from the metric token mask.
        attention_mask = torch.where(loss_dict['batch_ignore_loss'][:, None].repeat(1, attention_mask.shape[-1]), torch.full_like(attention_mask, False), attention_mask)
    losses = Loss(
        loss=loss_dict["loss"],
        img_loss=loss_dict.get("img_loss", 0),
        txt_loss=loss_dict.get("txt_loss", 0),
        nlls=std_nlls,
        txt_nlls=loss_dict.get("std_txt_loss", 0),
        img_nlls=loss_dict.get("std_img_loss", 0),
        token_mask=attention_mask,
        modality_mask=modality_mask,
        extra_losses=loss_dict.get("extra_losses", None),
    )
    if getattr(self.config.trainer, "disable_torchmetrics", False):
        raise NotImplementedError("Torchmetrics disabled")
    elif prefix == "train":
        return losses
    elif prefix == "val":
        self.valid_metrics.update(losses.nlls, losses.token_mask)
        if hasattr(self, "valid_txt_metrics"):
            self.valid_txt_metrics.update(losses.txt_nlls, losses.modality_mask[..., 0] & losses.token_mask)
            self.valid_img_metrics.update(losses.img_nlls, losses.modality_mask[..., 1] & losses.token_mask)
    elif prefix == "test":
        self.test_metrics.update(losses.nlls, losses.token_mask)
        metrics = self.test_metrics
        self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
    else:
        raise ValueError(f"Invalid prefix: {prefix}")
@torch.no_grad()
def zero_shot_eval(self):
    """Run ``zero_shot_eval_step`` over the validation loader, then finalize.

    Stops after ``config.trainer.limit_val_batches`` batches when that limit is
    set, and always calls ``zero_shot_eval_epoch_end`` afterwards.
    """
    loader = self.validation_dataloader
    n_batches = len(loader)
    rprint(f"Zero shot eval with {n_batches} batches with limit_val_batches: {self.config.trainer.limit_val_batches}")
    progress = tqdm(enumerate(loader), total=n_batches, desc="Zero shot eval validation steps", disable=not is_main_process())
    for step_idx, eval_batch in progress:
        batch_limit = self.config.trainer.limit_val_batches
        reached_limit = batch_limit is not None and step_idx >= batch_limit
        if reached_limit:
            break
        self.zero_shot_eval_step(eval_batch, step_idx)
    self.zero_shot_eval_epoch_end()
def validate(self, state: TrainingState):
    """Run one validation epoch.

    Steps:
    1. Call ``validation_step`` on up to ``total_batches`` batches.
    2. Fetch an example batch; with ``eval.eval_large_batch`` it is several batches
       stacked together.
    3. Optionally store MAUVE references.
    4. Hand the example batch to ``on_validation_epoch_end`` and clean up.

    The ``state`` argument is not read here.
    """
    self.on_validation_epoch_start()
    eval_cfg = self.config.eval

    bypass = getattr(eval_cfg, "compute_val_metrics_standalone", False) and getattr(eval_cfg, "bypass_normal_validation", False)
    if bypass:
        self.on_validation_epoch_end(example_batch=next(iter(self.validation_dataloader)))
        self.on_validation_epoch_cleanup()
        return

    # Iterable / indexed-webdataset loaders have no usable len(); use a nominal size.
    has_no_len = self.config.data.iterable or self.config.data.webdataset_indexed
    total_len = 10 if has_no_len else len(self.validation_dataloader)
    dprint(f"Validation batches: {total_len}")

    use_limit = self.config.trainer.limit_val_batches is not None and self.fid_eval is False
    total_batches = self.config.trainer.limit_val_batches if use_limit else total_len

    if getattr(eval_cfg, 'pplx_full_dataset', False):
        rprint("[INFO] PPLX full dataset eval, setting total_batches to total_len")
        total_batches = total_len
    elif eval_cfg.max_num_fid_batches_per_device is not None and self.fid_eval:
        total_batches = min(total_len, eval_cfg.max_num_fid_batches_per_device)

    loader = self.train_dataloader if eval_cfg.val_with_train_data else self.validation_dataloader
    rprint(f"Validating with {total_batches} batches on {self.world_size} GPUs with batch size {self.config.loader.eval_batch_size}")
    progress = tqdm(enumerate(loader), total=total_batches, desc="Validation steps", disable=not is_main_process())
    for step_idx, batch in progress:
        if self.config.trainer.limit_val_batches is not None and step_idx >= total_batches:
            break
        self.validation_step(batch, step_idx)

    n_large = getattr(eval_cfg, "eval_large_batch", None)
    if n_large is None:
        batch = next(iter(loader))
    else:
        assert isinstance(batch, TensorDict)
        fresh_iter = iter(loader)
        pulled = [next(fresh_iter, None) for _ in range(n_large)]
        batch = torch.stack([item for item in pulled if item is not None], dim=0)
        gprint(f"Large batch shape: {batch.shape}")

    if eval_cfg.visualize_data_only:
        return

    if eval_cfg.compute_standalone_mauve and not getattr(eval_cfg, "global_disable_mauve", False):
        self.mauve_store_references(loader)

    if self.config.mode == "eval":
        gprint(f"Batch shape: {batch['input_ids'].shape}")

    self.on_validation_epoch_end(example_batch=batch)
    self.on_validation_epoch_cleanup()
@cached_property
def global_batch_size(self):
    """Batch size for a single step over all GPUs.

    Under XLA SPMD every rank, regardless of node, acts as a single logical
    device, so the per-step size is not multiplied by the world size there.
    """
    per_step = self.step_batch_size
    single_logical_device = self.config.trainer.xla_spmd and is_xla_available
    replicas = 1 if single_logical_device else self.world_size
    return per_step * replicas
@cached_property
def step_batch_size(self):
    """Samples consumed per optimizer step on one GPU (micro-batch size x accumulation steps)."""
    micro_batch = self.config.loader.batch_size
    accum_steps = self.config.trainer.accumulate_grad_batches
    return micro_batch * accum_steps
@cached_property
def world_size(self):
    """Total number of GPUs/processes across all nodes, as reported by ``get_world_size``."""
    n_procs = get_world_size()
    return n_procs
@cached_property
def num_tokens_per_sample(self):
    """Sequence length of one sample, taken from ``config.model.length``."""
    model_cfg = self.config.model
    return model_cfg.length
@cached_property
def gradient_accumulation_steps(self):
    """How many micro-batches are accumulated per optimizer step (``trainer.accumulate_grad_batches``)."""
    trainer_cfg = self.config.trainer
    return trainer_cfg.accumulate_grad_batches
@cached_property
def static_txt_sl(self):
    """Slice over the leading ``config.model.txt_length`` sequence positions."""
    txt_end = self.config.model.txt_length
    return slice(None, txt_end)
@cached_property
def static_img_sl(self):
    """Slice over the trailing ``config.model.img_length`` sequence positions.

    ``slice(-0, None)`` equals ``slice(0, None)`` and would select the *entire*
    sequence. A zero image length therefore yields an empty slice instead.
    """
    img_length = self.config.model.img_length
    if img_length == 0:
        return slice(0, 0)
    return slice(-img_length, None)
def img_txt_pair_batch_mask(self, batch=None):
    """Per-sample flag that is True when the sample has at least one image position.

    Image positions are read from ``batch["modality_mask"][..., 1]``. ``batch`` must
    be provided despite its ``None`` default.
    """
    image_positions = batch["modality_mask"][..., 1]
    per_sample_count = image_positions.sum(dim=-1)
    return per_sample_count > 0
def txt_sl(self, batch=None):
    """Return the text channel (index 0) of ``batch["modality_mask"]``; ``batch`` is required."""
    modality_mask = batch["modality_mask"]
    return modality_mask[..., 0]
def img_sl(self, batch=None):
    """Return the image channel (index 1) of ``batch["modality_mask"]``; ``batch`` is required."""
    modality_mask = batch["modality_mask"]
    return modality_mask[..., 1]
@cached_property
def is_compiled(self):
    """Whether the model runs compiled: always under XLA, otherwise per ``trainer.compile``."""
    if is_xla_available:
        return is_xla_available
    return self.config.trainer.compile
@property
def allow_slicing(self):
    """True only when not on XLA and the backbone is in eval mode.

    Callers use in-place indexed writes when this is True, and out-of-place
    ``torch.where`` / clones otherwise.
    """
    return not (is_xla_available or self.backbone.training)
@property
def training(self):
    """Mirror of ``self.backbone.training``."""
    backbone = self.backbone
    return backbone.training
def get_step_metrics(self):
    """Return step/sample/token counters derived from ``self.global_step``.

    The "effective" counters are halved for the ``subs`` parameterization.
    """
    step = self.global_step
    samples = step * self.global_batch_size
    tokens = samples * self.config.model.length
    subs_scale = 0.5 if self.config.parameterization == "subs" else 1.0
    metrics = {}
    metrics["trainer/global_step"] = step
    metrics["global_samples"] = samples
    metrics["train_metrics/global_tokens"] = tokens
    metrics["effective_global_tokens"] = tokens * subs_scale
    metrics["effective_global_step"] = int(step * subs_scale)
    return metrics
def train(self):
tr = self.config.trainer
total_batch_size = self.global_batch_size
initial_global_step = self.global_step
true_step = 0
first_epoch = 0
self.current_run_global_step = 0
self.current_run_fwd_bwd_pass = 0
rprint(f"Started at step {self.accelerator.step}")
if self.non_embedding_params < 1e9:
with try_except(write_error_to_file=True, clear_cuda_cache=True):
self.print_hashes()
# There is an unknown bug with accelerator where non-master ranks don't load the step count from a checkpoint.
# We workaround by broadcasting the step count if necessary
if is_torch_cuda_available():
dprint(f"Gathering step from {self.world_size} ranks")
starting_steps = gather_object([self.accelerator.step])
rprint(f"Starting steps: {starting_steps}")
if not all([x > 0 for x in starting_steps]):
rprint(f"Not all ranks have >0 step, setting to: {starting_steps[0]}")
self.accelerator.step = starting_steps[0]
if is_xla_available:
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
assert (self.config.trainer.accumulate_grad_batches == 1) or getattr(self.config.trainer, "allow_accum_grad_batches_xla", False), "Accumulate grad batches must be 1 for XLA"
rprint(f"***** Starting training at global step: {self.global_step} *****")
rprint(f" Instantaneous batch size per device = {self.config.loader.batch_size}")
rprint(f" Gradient Accumulation steps = {tr.accumulate_grad_batches}")
rprint(f" Num GPUs = {tr.devices}")
rprint(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
rprint(f" Total optimization steps = {tr.max_steps}")
rprint(f" Reported Global Batch Size: {self.global_batch_size}, Reported Step Batch Size: {self.step_batch_size}, Reported World Size: {self.world_size}")
if not self.config.data.iterable and not self.config.data.webdataset_indexed and is_torch_cuda_available():
num_epoch_steps = len(self.train_dataloader)
rprint(f" Num examples = {len(self.train_dataloader.dataset)}")
rprint(f" Num batches each epoch = {len(self.train_dataloader)}")
rprint(f"Train Dataloader Size on single GPU: {num_epoch_steps}")
if len(self.train_dataloader.dataset) < total_batch_size:
rprint("The training dataloader is smaller than the total batch size. This may lead to unexpected behaviour.")
else:
num_epoch_steps = 10000
if self.config.trainer.pytorch_profile:
profiler = Profiler(
output_dir=self.config.output_dir, warmup_steps=tr.profiler_warmup_steps, active_steps=tr.profiler_active_steps, record_memory=True
)
if self.config.trainer.viz_images_only:
return self.viz_images_from_dataloader()
progress_bar = tqdm(range(0, tr.max_steps), initial=initial_global_step, desc="Steps", disable=not is_local_main_process(), leave=False, smoothing=0.15)
global_step_metrics = defaultdict(float)
global_extra_wandb_metrics = dict()
accumulate_steps = 0
first_start_time = time.time()
self.on_train_start()
rprint(f"Training for {tr.num_epochs} epochs...")
last_end_step_time = start_timing(f"Dataloading accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
for epoch in range(first_epoch, tr.num_epochs):
rprint(f"Starting epoch {epoch}...")
for step, batch in enumerate(self.train_dataloader):
ddprint(f"Data Step: {step}")
if self.config.trainer.iterate_dataloader_only:
rprint(f"Iterating dataloader only: {step}")
# rprint((batch["modality"] == 0).sum(), (batch["modality"] == 1).sum())
if (batch["attention_mask"] == 0).all(dim=-1).any():
breakpoint()
batch = self.update_batch(batch)
if (batch["sample_ids"] == -1).all(dim=-1).any():
breakpoint()
continue
elif getattr(self.config.trainer, "iterate_dataloader_n_dataloader_batches", None) is not None and step <= self.config.trainer.iterate_dataloader_n_dataloader_batches:
self.current_run_fwd_bwd_pass += 1
if self.current_run_fwd_bwd_pass % self.config.trainer.accumulate_grad_batches == 0:
self.global_step += 1
self.current_run_global_step += 1
ddprint(f"Iterating dataloader only for {self.config.trainer.iterate_dataloader_n_dataloader_batches} dataloader batches. At step {self.global_step=}, {self.current_run_global_step=}, {self.current_run_fwd_bwd_pass=}")
continue
if self.config.trainer.tpu_force_mark_step: xm.mark_step()
if self.config.trainer.sync_dataloader_timing: synchronize_device()
global_step_metrics[f"dataloading_time"] += end_timing(last_end_step_time)
if self.config.trainer.nvtx_profile and self.is_compiled and step == 4:
torch.cuda.cudart().cudaProfilerStart()
if self.current_run_global_step == 1 and is_xla_available:
gprint(f"First start time: {time.time() - first_start_time}")
if getattr(self.config.data, "force_dummy_tensordict", False):
gprint(self.global_step, self.current_run_global_step, true_step, batch["idx"].tolist(), batch["dataset_idx"].tolist())
if getattr(self.config.trainer, "assert_at_n_steps", None) is not None and self.global_step == self.config.trainer.assert_at_n_steps:
gprint(batch["img_input_ids"].min(), batch["img_input_ids"].max(), batch["txt_input_ids"].min(), batch["txt_input_ids"].max())
if batch is None:
rprint(f"Batch is None at step {step}")
continue
if self.config.trainer.tpu_force_mark_step: xm.mark_step()
ddprint(f"After Data Step 2: {step}")
with nullcontext() if is_xla_available else self.accelerator.accumulate(self.backbone):
ddprint(f"Before forward pass for global_step: {self.global_step}")
start_forward_time = start_timing(f"Forward Pass accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
global_step_metrics["examples_seen_per_gpu"] += len(next(iter(batch.values())))
state: TrainingState = TrainingState(
epoch_step=step,
num_epoch_steps=num_epoch_steps,
global_step=self.global_step,
epoch=epoch,
true_step=true_step,
current_run_global_step=self.current_run_global_step,
)
if self.accelerator.sync_gradients and is_xla_available is False:
self.cb_handler.on_train_step_start(state=state, unit=None)
if self.config.trainer.tpu_force_mark_step: xm.mark_step()
ddprint(f"Before Fwd: {step}")
with xp.StepTrace('Forward', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
losses = self.training_step(batch, step)
ddprint(f"After Fwd: {step}")
global_step_metrics["forward_pass_time"] += end_timing(start_forward_time)
true_step += 1
evaluate_extra_log_data = lambda: dict()
if self.config.trainer.tpu_force_mark_step: xm.mark_step()
if isinstance(losses, dict):
for k, v in losses.items():
if isinstance(v, torch.Tensor):
global_step_metrics[k.removeprefix("metric_")] += v.detach().cpu().item()
else:
global_extra_wandb_metrics[k.removeprefix("metric_")] = v
losses = dict(
filter(lambda item: not item[0].startswith("metric_"), losses.items())
) # Allow for custom metrics that are not losses
loss = sum(losses.values())
elif isinstance(losses, Loss):
loss = losses.loss
metrics = self.train_metrics(losses.nlls, losses.token_mask)
if hasattr(self, "txt_metrics") and losses.modality_mask is not None:
txt_metrics = self.txt_metrics(losses.txt_nlls, losses.modality_mask[..., 0] & losses.token_mask)
if hasattr(self, "img_metrics") and losses.modality_mask is not None:
img_metrics = self.img_metrics(losses.img_nlls, losses.modality_mask[..., 1] & losses.token_mask)
extra_losses_dict = losses.extra_losses
extra_losses_dict = extra_losses_dict if extra_losses_dict is not None else dict()
if self.config.trainer.tpu_force_mark_step: xm.mark_step()
def evaluate_extra_log_data(*, _has_modality_mask=losses.modality_mask is not None):
    """Build per-modality (text / image) training-metric entries for logging.

    This is invoked after the step body has run (directly, or through an XLA step
    closure), by which point `losses` has already been `del`-ed. Whether this step
    actually assigned `txt_metrics` / `img_metrics` is therefore captured when the
    function is defined, via the keyword-only default `_has_modality_mask` (early
    binding), instead of re-reading `losses` at call time.

    Returns:
        dict mapping "train/txt_<metric>" and "train/img_<metric>" to metric values,
        where <metric> is the last "/"-separated component of each metric key.
        Empty when no per-modality metrics were computed this step.
    """
    extra = {}
    if not _has_modality_mask:
        # txt_metrics / img_metrics are only assigned when modality_mask is present.
        # Reading them here would raise NameError on the first step, or return stale
        # values from an earlier step afterwards.
        return extra
    if hasattr(self, "txt_metrics"):
        extra.update({f"train/txt_{k.split('/')[-1]}": v for k, v in replace_nan_dict(txt_metrics).items()})
    # Checked independently: img_metrics is only assigned when self.img_metrics exists.
    if hasattr(self, "img_metrics"):
        extra.update({f"train/img_{k.split('/')[-1]}": v for k, v in replace_nan_dict(img_metrics).items()})
    return extra
ddprint(f"Before loss: {step}")
incremental_dict_update(global_extra_wandb_metrics, {
"trainer/loss": loss,
"trainer/img_loss": losses.img_loss,
"trainer/txt_loss": losses.txt_loss,
**{
"global_samples": self.global_step * self.global_batch_size,
"train_metrics/global_tokens": self.global_step * self.global_batch_size * self.config.model.length,
"effective_global_tokens": self.global_step * self.global_batch_size * self.config.model.length * (0.5 if self.config.parameterization == "subs" else 1.0),
"effective_global_step": int(self.global_step * (0.5 if self.config.parameterization == "subs" else 1.0)),
},
**metrics,
**extra_losses_dict,
})
if self.config.trainer.tpu_force_mark_step: xm.mark_step()
else:
loss = losses
if is_torch_cuda_available():
global_step_metrics["loss"] = loss.detach().cpu().item() # Only on the main process to avoid syncing
ddprint(f"Before backward pass for global_step: {self.global_step}")
# Short-circuit to avoid XLA eval
if tr.backward_pass and (is_xla_available or torch.isfinite(loss).all()):
start_backward_time = start_timing(f"Backward Pass accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
if self.accelerator.sync_gradients:
start_sync_time = start_timing(f"Gradient Sync global_step:{self.global_step}")
if getattr(self.config.trainer, "sync_timing", False):
sync_times(self.device)
if self.config.trainer.tpu_force_mark_step: xm.mark_step()
# After each fwd, we perform a bwd. However, if we are accumulating there is an internal no_sync so the gradients remain on the GPU until
# the final bwd before a step. This can be controlled by sync_each_batch. Note that for the last bwd, the sync happens inside the bwd call below, so any timing for stragglers needs to happen before this call.
with xp.StepTrace('Backward', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
ddprint(f"Before accelerator.backward for global_step: {self.global_step}")
self.accelerator.backward(loss)
ddprint(f"After accelerator.backward for global_step: {self.global_step}")
with xp.StepTrace('After Backward + Clip', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
if self.accelerator.sync_gradients:
ddprint(f"Before after.backward for global_step: {self.global_step}")
self.after_backward(state)
if tr.gradient_clip_val is not None:
ddprint(f"Before self.accelerator.clip_grad_norm_ for global_step: {self.global_step}")
total_grad_norm = self.accelerator.clip_grad_norm_(self.backbone.parameters(), tr.gradient_clip_val)
ddprint(f"After self.accelerator.clip_grad_norm_ for global_step: {self.global_step}")
with xp.StepTrace('Optimizer + Scheduler Step', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
ddprint(f"Before optimizer step for global_step: {self.global_step}, {step}")
if is_xla_available and False:
# TODO: xm.optimizer_step(self.optimizer) does not appear to be needed for XLA
xm.optimizer_step(self.optimizer)
else:
self.optimizer.step()
ddprint(f"After optimizer step for global_step: {self.global_step}, {step}")
self.lr_scheduler.step()
ddprint(f"After lr_scheduler step for global_step: {self.global_step}, {step}")
zero_grad_kwargs = dict()
if "apex" not in self.config.trainer.optimizer_cls:
zero_grad_kwargs["set_to_none"] = tr.set_grads_to_none
ddprint(f"Before zero_grad for global_step: {self.global_step}, {step}")
self.optimizer.zero_grad(**zero_grad_kwargs)
ddprint(f"Zeroed gradients for global_step: {self.global_step}, {step}")
if self.accelerator.sync_gradients:
if self.ema is not None:
if self.config.trainer.use_custom_ema:
ema_update(self.unwrap_model(self.ema), self.unwrap_model(self.backbone), self.config.trainer.ema)
else:
self.ema.step(self.get_params())
global_step_metrics["gradient_sync_time"] += end_timing(start_sync_time)
global_step_metrics["backward_pass_time"] += end_timing(start_backward_time)
else:
if not torch.isfinite(loss).all(): gprint(f"Loss is not finite: {loss}")
gprint("Skipping backward pass!")
accumulate_steps += 1
self.current_run_fwd_bwd_pass += 1
# Important: A single "global_step" is a single optimizer step. The accumulate decorator silently skips backward + optimizer to allow for gradient accumulation.
# A "true_step" counts the number of forward passes (on a per-GPU basis). The condition below should only happen immediately after a backward + optimizer step.
ddprint(f"Syncing gradients for global_step: {self.global_step}. Should sync: {self.accelerator.sync_gradients}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
if self.accelerator.sync_gradients:
start_gradient_sync_time = start_timing(f"On Sync Gradients global_step:{self.global_step}, {step}")
ddprint(f"Before on_train_step_end for global_step: {self.global_step}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
state.batch = batch
del loss, losses, batch
gradient_sync_time_after_train_step_end_time = start_timing(f"On Sync Gradients global_step:{self.global_step}, {step}")
self.on_train_step_end(state)
ddprint(f"After on_train_step_end for global_step: {self.global_step}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
global_step_metrics["gradient_sync_time_after_train_step_end"] += end_timing(gradient_sync_time_after_train_step_end_time)
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
if self.config.trainer.profile_memory and self.global_step + 1 >= tr.max_steps:
rprint("Finished profiling memory...")
break
if self.config.trainer.pytorch_profile and profiler.step(self.global_step):
rprint(f"Profiling finished at step: {self.global_step}")
break
if getattr(self.config.trainer, "throw_failure_for_testing", False) and self.current_run_global_step == 5:
raise RuntimeError("Test failure")
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
progress_bar.update(1)
self.global_step += 1
self.current_run_global_step += 1
global_step_metrics["gradient_sync_time"] += end_timing(start_gradient_sync_time)
logs = {
"examples_seen": self.global_step * total_batch_size,
"trainer/global_step": self.global_step,
**{k:v for k, v in global_step_metrics.items()},
**{f"lr_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_last_lr())},
**global_extra_wandb_metrics,
}
if is_torch_cuda_available():
logs["gpu_max_mem_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
logs["gpu_cur_mem_reserved_gb"] = torch.cuda.memory_reserved() / (1024**3)
logs["gpu_max_mem_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
logs["gpu_cur_mem_allocated_gb"] = torch.cuda.memory_allocated() / (1024**3)
if is_xla_available:
if self.global_step % getattr(self.config.trainer, "log_every_n_steps", 1) == 0:
xm.add_step_closure(update_logs, args=(logs, evaluate_extra_log_data), run_async=False)
del logs
global_extra_wandb_metrics = dict()
if self.config.trainer.tpu_force_mark_step: xm.mark_step()
else:
logs.update(evaluate_extra_log_data())
progress_bar.set_postfix(**logs)
log(logs)
global_extra_wandb_metrics = dict()
if getattr(self.config.trainer, "sync_timing", False):
global_step_metrics = {f"rank_{get_rank()}/{k}": v for k, v in global_step_metrics.items()}
all_step_metrics = self.accelerator.gather_for_metrics([global_step_metrics], use_gather_object=True)
merged_metrics = {k: v for d in all_step_metrics for k, v in d.items()}
log(merged_metrics)
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
global_step_metrics = defaultdict(float)
accumulate_steps = 0
if self.global_step >= tr.max_steps:
break
ddprint(f"After logging for step v3: {self.global_step}, {step}")
if getattr(self.config.trainer, "assert_at_n_steps", None) is not None and self.current_run_global_step >= getattr(self.config.trainer, "assert_at_n_steps", None):
raise RuntimeError(f"Assertion failed at step {self.current_run_global_step}")
ddprint(f"After logging for step v4: {self.global_step}, {step}")
if is_xla_available and self.config.trainer.tpu_profile and (self.global_step == 0 or self.global_step % 50 == 0) and is_main_process():
import torch_xla.debug.metrics as met
rprint(met.metrics_report())
met.clear_all()
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
ddprint(f"Finished sync_gradients: {self.global_step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
ddprint(f"Finished step: {self.global_step},{step},{self.accelerator.step},{self.accelerator.gradient_accumulation_steps},{self.accelerator.gradient_state.__repr__()}")
if self.config.trainer.sync_dataloader_timing: synchronize_device()
last_end_step_time = start_timing(f"Dataloading #{true_step + 1}")
if self.global_step >= tr.max_steps:
break
dprint(f"Finished epoch: {epoch}")
# Create the pipeline using using the trained modules and save it.
rprint("Training finished.")
barrier()
if tr.profile_memory:
print_memory(verbose=True)
save_memory_profile(self.config.output_dir / "profile")
if tr.pytorch_profile:
profiler.finish()
elif tr.nvtx_profile:
torch.cuda.cudart().cudaProfilerStop()
elif self.global_step > 100 or tr.skip_early_checkpointing is False:
self.checkpoint(state)
barrier()