# unidisc/configs/experiments/small_scale_train.yaml
# @package _global_
defaults:
  - vq16_magvit
  - override /model: small
  - override /lr_scheduler: constant_warmup_cosine_decay
model:
  img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
  txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
  length: ${eval:'${.txt_length} + ${.img_length}'}
  image_model: true
  text_model: true
  unified_model: true
  image_model_fid_eval: false
  force_argmax_valid_indices: true
  use_pretrained_img_emb: false
  codebook_embed_dim: 256
  qk_norm: true
  norm_type: rms
  sandwich_normalization: true
  zero_linear_init: false
  modality_embed: true
  rope_2d: false
  use_spda_attn: true
  force_optimized_native_attn: true
  freeze_txt_emb: false
  add_labels: null
  txt_dropout: null
  text_vocab_size: 32001
data:
  train: combined_tokens
  valid: ${.train}
  n_duplicate_train: null
  wrap: true
  streaming: false
  precache: false
  tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
  resolution: 256
  block_size: 128
  n_val_samples: null
  unpaired: false
  n_duplicate_val: null
  save_train_dataloader: true
  save_validation_dataloader: true
  iterable: false
  webdataset_iterable: false
  webdataset_indexed: false
  dataset_type: null
  tokens_flip_collate: false
  n_train_samples: null
  raw_data_dir: null
  tokenizers_parallelism: false
  token_data_dir: null
  force_disable_shuffle: false
  keep_tensordict_on_disk: true
  use_custom_tensordict_collate: true
  force_mp_spawn: false
  enable_cuda_in_tensordict_collate: false
  use_weighted_tensordict_sampler: true
  fraction_txt_data: 0.0
  tokenize_vqvae_in_dataloader: false
  use_token_dataset: true
  image_dataset: tglcourse/lsun_church_train
  image_data_train: null
  image_data_val: null
  keep_hf_dataset_in_memory: true
  allow_label: false
  disable_text_modality: true
  force_raw_train_images: false
  aggressive_aug: true
  allow_aug_vqvae_dataloader: true
  move_tensordict_to_shm: false
  data_dir_train:
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
      weight: -1
      name: datacomp1b_8_magvit_train
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
      weight: -1
      name: cc12m_tokens_train_256
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
      weight: -1
      name: HPDv2_image_reward_v1_v2_v3_magvit
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
      weight: -1
      name: pick_score_sac_prompts_v1_v2_v3_magvit
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
      weight: -1
      name: datacomp1b_0_1_6_magvit
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
      weight: -1
      name: laion400m_magvit_part_0
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
      weight: -1
      name: laion400m_magvit_part_1
  data_dir_val:
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
      weight: 1
      name: datacomp1b_8_magvit_val
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
      weight: 1
      name: cc12m_tokens_val_256
eval:
  generate_samples: true
  compute_generative_perplexity: true
  log_every_n_evals: 10
  log_every_n_fid: 20
  limit_val_batches_manual: 16
  perplexity_batch_size: ${loader.eval_batch_size}
  num_masking_viz_batches: -1
  cfg: null
  class_conditional_fid: false
  force_cfg_value: true
  split_cfg_batches: true
  max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
  fid_mode: clean
  clean_fid_precomputed_name: lsun_church
  clean_fid_precomputed_split: trainfull
  clean_fid_precomputed_res: 256
trainer:
  log_every_n_steps: 10
  val_check_interval: 1000
  custom_ddp_bf16: true
  scale_lr_by_batch_size: false
  limit_val_batches: 16
  use_gradient_checkpointing: false
  log_seperate_modal_losses: true
  softmin_snr: 5
  text_loss_weight: 1.0
  img_loss_weight: null
  low_precision_loss: false
  compile: true
  multimodal_batches: true
  compile_fullgraph: false
  log_grad_norm_every_n_steps: 10
  mask_entire_modality: 0.1
  force_shift_image_batches: false
  ckpt_steps: 10000
  ckpt_every_n_minutes: -1
  ignore_text_in_unified: false
  disable_all_eval_generation: true
  eval_on_start: false
  ckpt_model_only: false
  ema: 0.0
  use_custom_ema: false
  log_flops: false
  disable_distributed_torchmetrics: true
  restart_on_failure: true
  force_null_sigma: true
  allow_null_sigma: true
  compile_flag_pos_emb: true
  add_label: false
  first_token_dropout: null
  force_shift_raw_image_batches: true
  txt_dropout: 0.1
  force_full_attention_mask_loss_only: true
optim:
  lr: 0.0003
  weight_decay: 0.05
loader:
  batch_size: 64
  eval_batch_size: ${loader.batch_size}
  num_workers: 4
  desired_global_batch_size: 512
  persistent_workers: true
  pin_memory: true
  num_eval_workers: 1
sampling:
  steps: ${model.length}
  num_sample_batches: 2
  max_sampling_steps: ${model.length}
wandb:
  mode: online
lr_scheduler:
  num_warmup_steps: 5000
  num_training_steps: ${trainer.max_steps}
checkpointing:
  checkpoints_total_limit: 10