|
import os |
|
import sys |
|
from contextlib import ExitStack |
|
from pathlib import Path |
|
|
|
from constants import CONFIG_PATH, LIB_DIR |
|
sys.path.append(str(LIB_DIR / "hydra_submitit_launcher")) |
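# The vendored hydra-submitit launcher (under LIB_DIR) is put on sys.path before Hydra is
# imported so that its Slurm launcher plugin can be discovered.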
|
|
|
import builtins |
|
import random |
|
import re |
|
import signal |
|
import traceback |
|
from copy import deepcopy |
|
from datetime import datetime |
|
|
|
import hydra |
|
import numpy as np |
|
import omegaconf |
|
from hydra.core.hydra_config import HydraConfig |
|
from omegaconf import DictConfig, OmegaConf, open_dict, read_write |
|
from safetensors.torch import load_file, save_file |
|
|
|
import dataloader |
|
from model import Diffusion |
|
import utils |
|
import wandb |
|
from decoupled_utils import (check_gpu_memory_usage, get_hostname, |
|
get_local_rank, get_rank, get_slurm_filename_info, |
|
get_slurm_log_prefix, get_tpu_devices, |
|
get_world_size, gprint, is_local_main_process, |
|
is_main_process, is_torch_cuda_available, |
|
is_torch_xla_available, print_params, |
|
process_file_prefix, profile_memory, rank_zero_fn, |
|
rprint, set_global_breakpoint, set_global_exists, |
|
set_timing_builtins, try_except) |
|
from utils import (ErrorHandler, _print_config, convert_state_dict_keys, set_omega_conf_resolvers, set_torch_defaults) |
|
|
|
|
|
|
|
|
|
set_global_breakpoint() |
|
set_global_exists() |
|
set_omega_conf_resolvers() |
|
|
|
if is_torch_xla_available(): |
|
from jax_smi import initialise_tracking |
|
|
|
def _load_from_checkpoint(config, tokenizer): |
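    """Resolve the config and build a Diffusion model: HF backbones are constructed fresh (on CUDA),
    everything else is restored from ``config.eval.checkpoint_path``."""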
|
OmegaConf.resolve(config) |
|
if "hf" in config.backbone: |
|
return Diffusion(config=config, tokenizer=tokenizer).to("cuda") |
|
|
|
return Diffusion.load_from_checkpoint(config.eval.checkpoint_path, tokenizer=tokenizer, config=config) |
|
|
|
@rank_zero_fn |
|
def _print_batch(train_ds, valid_ds, tokenizer, k=64): |
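    """Print the shape plus the first/last ``k`` decoded tokens of one batch from the train and valid dataloaders (rank zero only)."""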
|
for dl_type, dl in [("train", train_ds), ("valid", valid_ds)]: |
|
rprint(f"Printing {dl_type} dataloader batch.") |
|
batch = next(iter(dl)) |
|
rprint("Batch input_ids.shape", batch["input_ids"].shape) |
|
first = batch["input_ids"][0, :k] |
|
last = batch["input_ids"][0, -k:] |
|
rprint(f"First {k} tokens:", tokenizer.decode(first)) |
|
rprint("ids:", first) |
|
rprint(f"Last {k} tokens:", tokenizer.decode(last)) |
|
rprint("ids:", last) |
|
|
|
|
|
def generate_samples(config, tokenizer): |
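    """Load the eval checkpoint and generate ``config.sampling.num_sample_batches`` batches of samples,
    reporting generative perplexity for the non-semi-AR path."""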
|
rprint("Generating samples.") |
|
model = _load_from_checkpoint(config=config, tokenizer=tokenizer) |
|
model.gen_ppl_metric.reset() |
|
if config.eval.disable_ema: |
|
rprint("Disabling EMA.") |
|
model.ema = None |
|
stride_length = config.sampling.stride_length |
|
num_strides = config.sampling.num_strides |
|
for _ in range(config.sampling.num_sample_batches): |
|
if config.sampling.semi_ar: |
|
_, intermediate_samples, _ = model.restore_model_and_semi_ar_sample( |
|
stride_length=stride_length, num_strides=num_strides, dt=1 / config.sampling.steps |
|
) |
|
text_samples = intermediate_samples[-1] |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
samples = model.restore_model_and_sample(num_steps=config.sampling.steps) |
|
text_samples = model.tokenizer.batch_decode(samples) |
|
model.compute_generative_perplexity(text_samples) |
|
|
|
rprint("Text samples:", text_samples) |
|
if not config.sampling.semi_ar: |
|
rprint("Generative perplexity:", model.gen_ppl_metric.compute()) |
|
return text_samples |
|
|
|
|
|
def instantiate_wandb(config, accelerator): |
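    """Initialize wandb tracking through Accelerate, optionally upload the source tree, and record the run URL and custom metrics on the config."""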
|
if is_torch_xla_available(): |
|
gprint("Initializing wandb for XLA") |
|
if config.mode == 'eval': |
|
config.wandb.project = f"{config.wandb.project}-eval" |
|
elif config.mode == 'zero-shot-eval': |
|
config.wandb.project = f"{config.wandb.project}-zero-shot-eval" |
|
|
|
if config.wandb.group is not None: |
|
config.wandb.group = str(config.wandb.group) |
|
|
|
|
|
|
|
wandb_kwargs = dict(config.wandb) |
|
|
|
if getattr(config, "sweep_id", None) is not None: |
|
rprint(f"Setting Wandb group to {config.sweep_id}") |
|
wandb_kwargs["group"] = config.sweep_id |
|
del wandb_kwargs["project"] |
|
accelerator.init_trackers( |
|
config.wandb.project, config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True), init_kwargs=dict(wandb=wandb_kwargs) |
|
) |
|
|
|
if getattr(config.trainer, "log_code", True) and is_main_process(): |
|
if "matrix" in get_hostname(): |
|
rprint(f"Not logging code to wandb on {get_hostname()}") |
|
else: |
|
rprint(f"Logging code to wandb from {Path(__file__).parent}") |
|
try: |
|
wandb.run.log_code( |
|
root=str(Path(__file__).parent), |
|
include_fn=lambda path: any(path.endswith(f) for f in (".py", ".yaml", ".yml", ".txt", ".md")), |
|
exclude_fn=lambda path, root: any(x in os.path.relpath(path, root) for x in ("output", "multirun", "logs", "wandb")), |
|
) |
|
except Exception as e: |
|
rprint(f"Failed to log code to wandb: {e}") |
|
|
|
with open_dict(config): |
|
try: |
|
config.wandb_url = wandb.run.get_url() |
|
wandb.define_metric("global_samples") |
|
wandb.define_metric("effective_global_tokens") |
|
wandb.define_metric("effective_global_step") |
|
wandb.define_metric("train_metrics/samples") |
|
wandb.define_metric("trainer/loss", step_metric="global_samples") |
|
except Exception as e: |
|
rprint(f"Failed to get wandb url: {e}") |
|
|
|
def instantiate_model(config, tokenizer): |
|
model = _load_from_checkpoint(config=config, tokenizer=tokenizer) |
|
if config.eval.disable_ema: |
|
rprint("Disabling EMA.") |
|
model.ema = None |
|
|
|
return model |
|
|
|
def gconf(config, attr): |
|
return getattr(config, attr, None) |
|
|
|
|
|
def has_ckpt(config, attr): |
|
return gconf(config, attr) is not None and utils.fsspec_exists(gconf(config, attr)) |
|
|
|
|
|
def set_env_vars(config): |
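    """Configure process-level settings (RLIMIT_MEMLOCK, NCCL P2P, cuDNN SDPA flags, anomaly detection) before distributed setup."""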
|
    import socket

    import torch

    hostname = socket.gethostname()
    rprint(f"Starting Training on {hostname}")
|
|
|
|
|
if not is_torch_xla_available(): |
|
try: |
|
|
|
|
|
import resource |
|
soft, hard = resource.getrlimit(resource.RLIMIT_MEMLOCK) |
|
resource.setrlimit(resource.RLIMIT_MEMLOCK, (hard, hard)) |
|
if is_local_main_process(): |
|
gprint(f"Successfully set RLIMIT_MEMLOCK to {hard}") |
|
except ValueError as e: |
|
rprint(f"Failed to set RLIMIT_MEMLOCK: {e}") |
|
except resource.error as e: |
|
rprint(f"Error setting RLIMIT_MEMLOCK: {e}") |
|
else: |
|
rprint(f"Not setting RLIMIT_MEMLOCK on XLA") |
|
|
|
if "matrix-3-28" in hostname or "matrix-3-26" in hostname: |
|
rprint(f"Disabling NCCL P2P") |
|
os.environ["NCCL_P2P_DISABLE"] = "1" |
|
|
|
if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") != "": |
|
assert False, f"TORCH_DISTRIBUTED_DEBUG is set to: {os.environ.get('TORCH_DISTRIBUTED_DEBUG')}. Please unset it as it starts a gloo backend." |
|
|
|
if config.model.use_spda_attn: |
|
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1" |
|
os.environ["TORCH_CUDNN_MHA_ENABLED"] = "1" |
|
rprint("Setting SPDA Flags") |
|
|
|
if config.trainer.detect_anomaly: |
|
torch.autograd.set_detect_anomaly(True) |
|
|
|
def update_config_before_resolution(config): |
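    """Adjust the config in place (output dirs, batch sizes, sampling steps, checkpoint symlinks, env-var bookkeeping) before OmegaConf resolution."""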
|
import torch |
|
if hasattr(config, "training"): |
|
rprint(f"'training' has been refactored to 'trainer'. Please update the config.") |
|
|
|
with open_dict(config): |
|
config.output_dir = os.getcwd() |
|
config.logging_dir = os.getcwd() |
|
if config.model.use_kv_cache is False and config.mode == "eval" and config.loader.eval_batch_size > 1: |
|
config.loader.eval_batch_size = max(config.loader.eval_batch_size, 16) |
|
|
|
|
|
if getattr(config.eval, 'txt_img_ratio', None) is not None: |
|
|
|
tot = config.model.length |
|
|
|
|
|
|
|
|
|
if config.eval.txt_img_ratio == 2: |
|
|
|
config.model.txt_length = int(tot * 2/3) |
|
elif config.eval.txt_img_ratio == 1: |
|
config.model.txt_length = int(tot / 2) |
|
elif config.eval.txt_img_ratio == 0.5: |
|
config.model.txt_length = int(tot * 2/3) |
|
elif config.eval.txt_img_ratio == 0.25: |
|
config.model.txt_length = int(tot / 4) |
|
config.model.img_length = tot - config.model.txt_length |
|
config.model.length = config.model.txt_length + config.model.img_length |
|
|
|
|
|
if getattr(config.eval, "varying_seq_len_ratio", False): |
|
assert getattr(config.eval, "sampling_step_ratio", None) is not None, "Must set both varying_seq_len_ratio and sampling_step_ratio" |
|
config.sampling.steps = int(config.model.length * config.eval.sampling_step_ratio) |
|
|
|
if getattr(config.eval, "ablation_config", False): |
|
if config.parameterization == "ar": |
|
rprint(f"WARNING!!!!! FORCING AR PARAMS") |
|
config.trainer.ar_shift = True |
|
config.model.full_attention = False |
|
|
|
config.data.keep_tensordict_on_disk = True |
|
if is_torch_cuda_available(): |
|
if any(x.lower() in torch.cuda.get_device_name().lower() for x in ["v100", "1080", "2080", "quadro", "titan"]) or torch.cuda.get_device_capability()[0] <= 7: |
|
rprint(f"Using 2080Ti/V100, setting precision to fp32") |
|
config.trainer.precision = "no" |
|
config.model.force_optimized_native_attn = False |
|
config.trainer.compile = False |
|
if any(x.lower() in torch.cuda.get_device_name().lower() for x in ["2080", "quadro"]): |
|
config.loader.eval_batch_size = config.loader.eval_batch_size // 7 |
|
config.loader.batch_size = config.loader.batch_size // 7 |
|
elif any(x.lower() in torch.cuda.get_device_name().lower() for x in ["1080", "titan"]): |
|
config.loader.eval_batch_size = config.loader.eval_batch_size // 6 |
|
config.loader.batch_size = config.loader.batch_size // 6 |
|
else: |
|
config.loader.eval_batch_size = config.loader.eval_batch_size // 2 |
|
config.loader.batch_size = config.loader.batch_size // 2 |
|
elif "a5000" in torch.cuda.get_device_name().lower() or "a4500" in torch.cuda.get_device_name().lower(): |
|
config.loader.eval_batch_size = config.loader.eval_batch_size // 2 |
|
config.loader.batch_size = config.loader.batch_size // 2 |
|
else: |
|
rprint(f"Found {torch.cuda.get_device_name()}") |
|
config.loader.eval_batch_size = config.loader.eval_batch_size // 2 |
|
config.loader.batch_size = config.loader.batch_size // 2 |
|
|
|
if getattr(config, "parametierzation", None) == "ar" and config.eval.cfg is not None: |
|
config.loader.eval_batch_size = config.loader.eval_batch_size // 2 |
|
config.loader.batch_size = config.loader.batch_size // 2 |
|
|
|
config.loader.eval_batch_size = max(config.loader.eval_batch_size, 1) |
|
config.loader.batch_size = max(config.loader.batch_size, 1) |
|
|
|
if getattr(config, "parametierzation", None) == "ar": |
|
config.trainer.compile = False |
|
|
|
if getattr(config.sampling, "sampling_step_frac", None) is not None: |
|
config.sampling.steps = int(config.model.length * config.sampling.sampling_step_frac) |
|
rprint(f"Setting sampling steps to {config.sampling.steps}") |
|
|
|
if os.environ.get("SUBMITIT_FOLDER") is not None or os.environ.get("CUSTOM_SBATCH_LAUNCHER", "0") == "1": |
|
rprint(f'Using submitit folder: {os.environ.get("SUBMITIT_FOLDER", "")}, setting slurm=True') |
|
config.slurm = True |
|
|
|
if (config.debug is False or os.environ.get("HYDRA_RUN_DIR_NAME", None) is not None) and torch.distributed.is_torchelastic_launched(): |
|
config.trainer.restart_on_failure = True |
|
rprint(f"Setting restart_on_failure to True") |
|
|
|
if config.trainer.restart_on_failure and config.mode == 'train': |
|
if os.environ.get("HYDRA_RUN_DIR", None) is None and os.environ.get("HYDRA_RUN_DIR_NAME", None) is None: |
|
os.environ["HYDRA_RUN_DIR"] = config.output_dir |
|
rprint(f"Setting HYDRA_RUN_DIR to {os.environ['HYDRA_RUN_DIR']}") |
|
else: |
|
rprint(f"Not setting HYDRA_RUN_DIR, already set to {os.environ.get('HYDRA_RUN_DIR', 'N/A')}, and HYDRA_RUN_DIR_NAME is set to {os.environ.get('HYDRA_RUN_DIR_NAME', 'N/A')}") |
|
|
|
os.environ["RESTART_FAULT_TOLERANT"] = "1" |
|
rprint(f"Setting RESTART_FAULT_TOLERANT to 1") |
|
elif config.trainer.restart_on_failure: |
|
rprint(f"Restart_on_failure is True, but mode is not 'train', so not setting restart fault tolerant") |
|
|
|
relevant_vars = {} |
|
for key, value in os.environ.items(): |
|
if "SLURM" in key or "NCCL" in key or "TORCH" in key: |
|
relevant_vars[key] = value |
|
|
|
config.env_vars = relevant_vars |
|
|
|
if config.trainer.profile_memory: |
|
config.trainer.max_steps = 2 |
|
|
|
if config.debug and config.trainer.force_enable_checkpointing is False and (config.trainer.ckpt_steps is None or config.trainer.ckpt_steps > 0): |
|
config.trainer.ckpt_steps = 10000 |
|
rprint(f"Only checkpointing every {config.trainer.ckpt_steps} steps in debug mode") |
|
|
|
if config.loader.global_batch_size is None: |
|
config.loader.global_batch_size = config.loader.batch_size * config.trainer.accumulate_grad_batches * (1 if is_torch_xla_available() else get_world_size()) |
|
config.loader.eval_global_batch_size = config.loader.global_batch_size |
|
if config.trainer.scale_lr_by_batch_size: |
|
config.optim.lr = config.optim.lr * (config.loader.global_batch_size / 512) |
|
rprint(f"Setting global batch size to {config.loader.global_batch_size}, lr to {config.optim.lr}") |
|
|
|
if config.mode != 'train': |
|
config.checkpointing.resume_wandb = False |
|
config.wandb.resume = None |
|
|
|
if config.trainer.use_spmd_distributed_checkpointing is None: |
|
config.trainer.use_spmd_distributed_checkpointing = is_torch_xla_available() and config.trainer.xla_spmd |
|
|
|
if config.trainer.disable_all_eval_generation: |
|
            config.eval.num_masking_viz_batches = 0
            config.eval.num_uncond_sample_batches = 0
            config.eval.num_sample_batches = 0
            config.eval.num_random_masking = 0
            config.eval.generate_samples = False
            config.trainer.log_flops = False
            config.eval.log_every_n_evals = -1
|
config.eval.log_every_n_fid = -1 |
|
config.model.image_model_fid_eval = False |
|
rprint("Disabling all eval generation!!!") |
|
|
|
if os.environ.get("XLA_IR_DEBUG", "0") == "1": |
|
config.trainer.tpu_profile = True |
|
|
|
        if config.checkpointing_root_dir is not None:
            assert "checkpoints" in config.checkpointing.save_dir
            relative_path = Path(*Path(config.checkpointing.save_dir).relative_to(config.root_output_dir).parts[1:])
            full_checkpointing_dir = Path(config.checkpointing_root_dir) / relative_path
            old_save_dir = Path(config.output_dir) / "checkpoints"
            full_checkpointing_dir.mkdir(parents=True, exist_ok=True)
            try:
                if old_save_dir.exists():
                    rprint(f"WARNING: Cannot create symlink from {old_save_dir} to {full_checkpointing_dir} because {old_save_dir} exists.")
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    old_save_dir = Path(*old_save_dir.parts[:-1]) / f"checkpoints_{timestamp}"

                old_save_dir.symlink_to(full_checkpointing_dir, target_is_directory=True)
                rprint(f"Created symlink from {old_save_dir} to {full_checkpointing_dir}")

                original_link = full_checkpointing_dir.parent / "original_output_dir"
                if not original_link.exists():
                    original_link.symlink_to(Path(config.output_dir).resolve(), target_is_directory=True)
                    rprint(f"Created symlink from {original_link} to {config.output_dir}")
                else:
                    rprint(f"WARNING: Symlink {original_link} already exists. Skipping creation.")
            except OSError as e:
                rprint(f"Error creating symlinks: {e}")
|
|
|
assert getattr(config.data, "allow_label", False) == getattr(config.trainer, "add_label", False) == (getattr(config.model, "add_labels", None) is not None) == getattr(config.eval, "class_conditional_fid", False), f"Mismatching values: data.allow_label={config.data.allow_label}, trainer.add_label={config.trainer.add_label}, model.add_labels={config.model.add_labels}, eval.class_conditional_fid={config.eval.class_conditional_fid}" |
|
|
|
if getattr(config.loader, "num_eval_workers", None) is not None and config.loader.num_workers == 0: |
|
rprint(f"Setting num_eval_workers to 0 because num_workers is 0") |
|
config.loader.num_eval_workers = 0 |
|
|
|
if config.trainer.disable_all_checkpointing: |
|
gprint("-"*50) |
|
gprint(f"WARNING: DISABLING ALL CHECKPOINTING!!!!") |
|
gprint("-"*50) |
|
gprint(f"WARNING: DISABLING ALL CHECKPOINTING!!!!") |
|
gprint("-"*50) |
|
config.trainer.ckpt_steps = 100000000 |
|
|
|
if config.sampling.steps != config.sampling.max_sampling_steps: |
|
rprint(f"WARNING!!!! steps {config.sampling.steps} != max_sampling_steps {config.sampling.max_sampling_steps}") |
|
config.sampling.max_sampling_steps = config.sampling.steps |
|
|
|
def get_latest_ckpt(config, input_dir): |
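    """Return the most recent checkpoint folder inside ``input_dir`` (or ``input_dir`` itself for XLA SPMD runs), or None if nothing usable is found."""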
|
if input_dir is None or not Path(input_dir).exists(): |
|
rprint(f"Project dir {input_dir} does not exist") |
|
return None |
|
|
|
if config.trainer.xla_spmd and is_torch_xla_available(): |
|
rprint(f"XLA SPMD detected, using XLA checkpointing") |
|
if any(Path(input_dir).iterdir()): |
|
rprint(f"Found existing files/folders in {input_dir}") |
|
return input_dir |
|
else: |
|
rprint(f"No folders found in {input_dir}") |
|
return None |
|
|
|
folders = [str(folder) for folder in Path(input_dir).iterdir() if folder.is_dir() and ((folder / "model.safetensors").exists() or (folder / "config.yaml").exists())] |
|
|
|
if len(folders) == 0: |
|
rprint(f"No folders found in {input_dir}") |
|
return None |
|
|
|
def _inner(folder): |
|
return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0] |
|
|
|
folders.sort(key=_inner) |
|
rprint(f"Found folders: {folders}") |
|
input_dir = folders[-1] |
|
return input_dir |
|
|
|
def is_sweep(): |
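    """Return True when running inside a Hydra sweep (i.e., the sweep subdir is resolvable)."""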
|
try: |
|
subdir = HydraConfig.get().sweep.subdir |
|
rprint(f"Found sweep subdir: {subdir}") |
|
return True |
|
except omegaconf.errors.InterpolationToMissingValueError: |
|
return False |
|
|
|
def get_sweep_run_name(config): |
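    """Build a run-name suffix from the sweep subdir (or SLURM job id) plus any swept or forced config keys."""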
|
try: |
|
subdir = HydraConfig.get().sweep.subdir |
|
sweep_str = f"{subdir}_" |
|
is_sweep = True |
|
except omegaconf.errors.InterpolationToMissingValueError: |
|
is_sweep = False |
|
sweep_str = f"{os.environ.get('SLURM_JOB_ID', '')}_" |
|
|
|
if getattr(config, "training", None) is not None and getattr(getattr(config, "training", None), "force_keys", None) is not None: |
|
rprint("Using legacy keys") |
|
forced_keys = set(config.training.force_keys) |
|
else: |
|
forced_keys = set(getattr(config.trainer, "forced_keys", [])) |
|
|
|
if is_sweep: |
|
print( |
|
f"Getting sweep keys: {HydraConfig.get().job.sweep_keys}, Tasks: {HydraConfig.get().overrides.task}, {getattr(config.trainer, 'forced_keys', [])}" |
|
) |
|
valid_keys = set(HydraConfig.get().job.sweep_keys) |
|
        for task in HydraConfig.get().overrides.task:
            parts = task.removeprefix("+").split("=")
            key = parts[0]
            if key in valid_keys or key in forced_keys:
                value = parts[1]
                sweep_str += f"{key.split('.')[-1]}={value}__"
                if key in forced_keys:
                    forced_keys.remove(key)
                    print(f"Forced key: {key}={value}")
|
|
|
for key in sorted(list(forced_keys)): |
|
sweep_str += f"{key.split('.')[-1]}={OmegaConf.select(config, key)}__" |
|
|
|
rprint(f"Sweep: {is_sweep=}, {sweep_str=}") |
|
return "" if sweep_str == "" else sweep_str[:-2] |
|
|
|
def save_config_to_ckpt(config, output_dir, model): |
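    """Record the current global step and eval count on the config, then save it as ``config.yaml`` next to the checkpoint."""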
|
with try_except(write_error_to_file=True, clear_cuda_cache=True): |
|
with read_write(config): |
|
with open_dict(config): |
|
config.state.ckpt_step = model.global_step |
|
config.state.num_evals = model.num_evals |
|
|
|
OmegaConf.save(config=config, f=Path(output_dir) / "config.yaml") |
|
rprint(f"Saved global step {model.global_step}") |
|
|
|
def determine_ckpt(config): |
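    """Decide which checkpoint (if any) to resume from and configure the wandb run id, name, and resume behavior accordingly."""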
|
has_recent_ckpt = False |
|
rprint(f"Looking at checkpoint path: {getattr(config.checkpointing, 'resume_ckpt_path', None)}") |
|
if ( |
|
config.checkpointing.resume_from_ckpt |
|
and (latest_ckpt := get_latest_ckpt(config, getattr(config.checkpointing, "resume_ckpt_path", None))) is not None |
|
and (Path(latest_ckpt) / "config.yaml").exists() |
|
): |
|
ckpt_path = latest_ckpt |
|
has_recent_ckpt = True |
|
if config.slurm: |
|
config.wandb.resume = "must" |
|
rprint(f"Resuming from checkpoint {ckpt_path}") |
|
elif config.checkpointing.resume_from_ckpt and getattr(config.checkpointing, "initial_resume_ckpt_path", None) is not None: |
|
ckpt_path = config.checkpointing.initial_resume_ckpt_path |
|
rprint(f"Resuming from initial checkpoint {ckpt_path}") |
|
else: |
|
ckpt_path = None |
|
|
|
if ckpt_path is not None and (config.checkpointing.resume_wandb or has_recent_ckpt): |
|
loaded = OmegaConf.load(Path(ckpt_path) / "config.yaml") |
|
if loaded.wandb.id is not None: |
|
config.wandb.id = str(loaded.wandb.id) |
|
rprint(f"Found wandb id: {config.wandb.id}") |
|
else: |
|
rprint(f"No wandb id found in checkpoint {ckpt_path}") |
|
|
|
if config.checkpointing.resume_wandb and config.wandb.id is not None: |
|
config.wandb.resume = "must" |
|
rprint(f"Resuming wandb, setting must, run id: {config.wandb.id}") |
|
elif config.slurm and config.wandb.id is None: |
|
if os.environ.get("SLURM_ARRAY_TASK_COUNT", "") != "" and int(os.environ.get("SLURM_ARRAY_TASK_COUNT", "")) > 1: |
|
config.wandb.id = str(os.environ.get("SLURM_ARRAY_JOB_ID")) + f"_{os.environ.get('SLURM_ARRAY_TASK_ID')}" |
|
else: |
|
config.wandb.id = str(os.environ.get("SLURM_JOB_ID")) |
|
rprint(f"Setting wandb id to {config.wandb.id}") |
|
|
|
if config.checkpointing.initial_resume_ckpt_path is not None and config.checkpointing.resume_wandb: |
|
assert config.wandb.id is not None |
|
|
|
if config.ckpt is not None: |
|
ckpt_path = config.ckpt |
|
rprint(f"Running eval with checkpoint {ckpt_path}") |
|
|
|
if config.wandb.id is not None: |
|
config.wandb.id = str(config.wandb.id) |
|
|
|
if config.wandb.id is None or getattr(config.trainer, "force_new_wandb_id", False): |
|
config.wandb.id = wandb.util.generate_id() |
|
config.wandb.resume = "allow" |
|
rprint(f"Set wandb id: {config.wandb.id}") |
|
|
|
rprint(f"Using wandb id: {config.wandb.id}") |
|
subdir = get_sweep_run_name(config) |
|
rprint(f"Wandb name: {config.wandb.name}, Wandb subdir: {subdir}") |
|
|
|
if config.wandb.name == 'default': |
|
config.wandb.name = None |
|
else: |
|
config.wandb.name = ( |
|
(f"{config.wandb.name}_" if config.wandb.name else "") |
|
+ (f"{subdir}_" if (subdir is not None and subdir != "") else "") |
|
+ f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" |
|
) |
|
|
|
if getattr(config.wandb, "group", None) is None and subdir is not None and config.debug and os.environ.get("SLURM_ARRAY_JOB_ID", "") != "": |
|
config.wandb.group = os.environ.get("SLURM_ARRAY_JOB_ID") |
|
rprint(f"Wandb group: {config.wandb.group}") |
|
|
|
return ckpt_path |
|
|
|
def run(config, tokenizer): |
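    """Build the Accelerator, dataloaders, and Diffusion model, register checkpoint hooks, then train or evaluate according to ``config.mode``."""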
|
import torch |
|
from accelerate import (Accelerator, DataLoaderConfiguration, |
|
DDPCommunicationHookType, |
|
DistributedDataParallelKwargs, |
|
FullyShardedDataParallelPlugin) |
|
from accelerate.state import AcceleratorState |
|
from accelerate.utils import GradientAccumulationPlugin, ProjectConfiguration |
|
|
|
set_torch_defaults(config.trainer.benchmark) |
|
|
|
set_env_vars(config) |
|
update_config_before_resolution(config) |
|
ckpt_path = determine_ckpt(config) |
|
OmegaConf.resolve(config) |
|
if is_torch_cuda_available(): |
|
check_gpu_memory_usage() |
|
|
|
if is_torch_cuda_available(): |
|
rprint(f"pt={torch.__version__}, cuda={torch.version.cuda}, nccl={torch.cuda.nccl.version()}") |
|
rprint(f"GPU={torch.cuda.get_device_name()}, device compute capabilities={torch.cuda.get_device_capability()}, pytorch compute capabilities={torch.cuda.get_arch_list()}") |
|
elif is_torch_xla_available(): |
|
rprint(f"XLA Devices={get_tpu_devices()}") |
|
|
|
rprint( |
|
f"Initial GROUP_RANK: {os.environ.get('GROUP_RANK', 'N/A')}, RANK: {os.environ.get('RANK', 'N/A')}, LOCAL_RANK: {os.environ.get('LOCAL_RANK', 'N/A')}, WORLD_SIZE: {os.environ.get('WORLD_SIZE', 'N/A')}, MASTER_ADDR: {os.environ.get('MASTER_ADDR', 'N/A')}, MASTER_PORT: {os.environ.get('MASTER_PORT', 'N/A')}, TORCHELASTIC_RUN_ID: {os.environ.get('TORCHELASTIC_RUN_ID', 'N/A')}, TORCHELASTIC_RESTART_COUNT: {os.environ.get('TORCHELASTIC_RESTART_COUNT', 'N/A')}, TORCHELASTIC_MAX_RESTARTS: {os.environ.get('TORCHELASTIC_MAX_RESTARTS', 'N/A')}, LOCAL_WORLD_SIZE: {os.environ.get('LOCAL_WORLD_SIZE', 'N/A')}, Elastic: {torch.distributed.is_torchelastic_launched()}" |
|
) |
|
rprint(f"Computed Rank: {get_rank()}, Local Rank: {get_local_rank()}, World Size: {get_world_size()}") |
|
|
|
|
|
|
|
sync_timing = (config.trainer.nvtx_profile and getattr(config.trainer, "sync_nvtx_timing", True)) or getattr(config.trainer, "sync_timing", False) |
|
set_timing_builtins(enable=config.trainer.nvtx_profile, sync=sync_timing) |
|
|
|
num_nodes = config.trainer.num_nodes |
|
with open_dict(config): |
|
config.trainer = OmegaConf.merge(config.trainer, dict(mixed_precision=config.trainer.precision, log_with="wandb", log_gradients=None)) |
|
if getattr(config.trainer, "process_dataloader_only", False): |
|
gprint("Processing dataloader only") |
|
train_ds, valid_ds = dataloader.get_dataloaders(config, tokenizer, device="cpu", skip_train=(config.mode == 'eval' and not config.eval.val_with_train_data)) |
|
gprint(f"Exiting after processing dataloader") |
|
return |
|
|
|
accelerator_project_config = ProjectConfiguration( |
|
project_dir=config.output_dir, |
|
logging_dir=config.logging_dir, |
|
automatic_checkpoint_naming=config.checkpointing.use_automatic_naming, |
|
save_on_each_node=False, |
|
) |
|
|
|
accelerate_kwargs = dict() |
|
gradient_kwargs = dict() |
|
if config.trainer.fsdp and not (config.trainer.xla_spmd and is_torch_xla_available()): |
|
rprint("Using FSDP...") |
|
if config.backbone == "llama" or config.backbone == "gemma": |
|
os.environ["ACCELERATE_USE_FSDP"] = "true" |
|
os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP" |
|
os.environ["FSDP_BACKWARD_PREFETCH"] = "NO_PREFETCH" |
|
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true" |
|
os.environ["FSDP_FORWARD_PREFETCH"] = "false" |
|
os.environ["FSDP_OFFLOAD_PARAMS"] = "false" |
|
os.environ["FSDP_SHARDING_STRATEGY"] = "FULL_SHARD" |
|
os.environ["FSDP_STATE_DICT_TYPE"] = "SHARDED_STATE_DICT" |
|
os.environ["FSDP_SYNC_MODULE_STATES"] = "true" |
|
os.environ["FSDP_USE_ORIG_PARAMS"] = "true" |
|
fsdp_plugin = FullyShardedDataParallelPlugin() |
|
else: |
|
os.environ["ACCELERATE_USE_FSDP"] = "true" |
|
os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP" |
|
if config.backbone == "elm": |
|
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "OpenELMDecoderLayer" |
|
os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE" |
|
os.environ["FSDP_SHARDING_STRATEGY"] = "HYBRID_SHARD_ZERO2" |
|
else: |
|
|
|
os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE" |
|
|
|
os.environ["FSDP_SHARDING_STRATEGY"] = "HYBRID_SHARD_ZERO2" |
|
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "DDiTBlock" |
|
|
|
|
|
from torch.distributed.fsdp.fully_sharded_data_parallel import ( |
|
FullOptimStateDictConfig, FullStateDictConfig) |
|
fsdp_plugin = FullyShardedDataParallelPlugin( |
|
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), |
|
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), |
|
) |
|
|
|
if config.trainer.compile or config.trainer.use_orig_params is True: |
|
|
|
fsdp_plugin.use_orig_params = True |
|
rprint("Using orig params for FSDP. This is required for torch.compile to work.") |
|
|
|
accelerate_kwargs["fsdp_plugin"] = fsdp_plugin |
|
gradient_kwargs["sync_each_batch"] = False |
|
|
|
if getattr(config.trainer, "fsdp_sync_each_batch", False): |
|
rprint("Using sync each batch for Chameleon") |
|
gradient_kwargs["sync_each_batch"] = True |
|
|
|
elif config.trainer.xla_spmd is False: |
|
rprint("Using DDP...") |
|
ddp_kwargs = DistributedDataParallelKwargs( |
|
find_unused_parameters=config.trainer.find_unused_parameters, |
|
comm_hook=DDPCommunicationHookType.BF16, |
|
static_graph=config.trainer.accumulate_grad_batches == 1, |
|
gradient_as_bucket_view=True, |
|
) |
|
|
|
|
|
|
|
from datetime import timedelta |
|
|
|
from accelerate.utils import InitProcessGroupKwargs |
|
init_process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800)) |
|
accelerate_kwargs["kwargs_handlers"] = [ddp_kwargs, init_process_group_kwargs] |
|
else: |
|
rprint(f"Did not choose DDP or FSDP.") |
|
|
|
if config.trainer.accumulate_grad_batches <= 0: |
|
gprint("WARNING!!!!!! Accumulate grad batches is <= 0, setting to 1") |
|
config.trainer.accumulate_grad_batches = 1 |
|
|
|
gradient_accumulation_plugin = GradientAccumulationPlugin( |
|
num_steps=config.trainer.accumulate_grad_batches, |
|
adjust_scheduler=False, |
|
sync_with_dataloader=False, |
|
**gradient_kwargs |
|
) |
|
|
|
if config.trainer.mixed_precision == "bf16" and (is_torch_cuda_available() and not torch.cuda.is_bf16_supported()): |
|
rprint(f"No BF16 GPU found, falling back to FP16") |
|
config.trainer.mixed_precision = "fp16" |
|
|
|
if config.trainer.mixed_precision == "fp32": |
|
config.trainer.mixed_precision = "no" |
|
else: |
|
if is_torch_xla_available(): |
|
os.environ["ACCELERATE_DOWNCAST_BF16"] = "true" |
|
|
|
rprint(f"Mixed precision: {config.trainer.mixed_precision}") |
|
|
|
if config.seed is None or getattr(config.eval, 'set_random_gen_seed', False): |
|
|
|
accelerate_kwargs["rng_types"] = [] |
|
rprint("No seed provided, disabling accelerate RNG synchronization") |
|
|
|
accelerator = Accelerator( |
|
mixed_precision=config.trainer.mixed_precision, |
|
log_with=config.trainer.log_with, |
|
project_config=accelerator_project_config, |
|
gradient_accumulation_plugin=gradient_accumulation_plugin, |
|
dataloader_config=DataLoaderConfiguration(split_batches=False, dispatch_batches=False, non_blocking=False), |
|
**accelerate_kwargs, |
|
) |
|
|
|
gprint(f"Distributed Type: {accelerator.distributed_type}, Accelerator state: {AcceleratorState()}") |
|
num_processes = AcceleratorState().num_processes |
|
if getattr(config.trainer, "global_num_warmup_steps", None) is not None: |
|
rprint(f"Global num_warmup_steps was: {config.lr_scheduler.num_warmup_steps}. Applying to num_warmup_steps") |
|
config.lr_scheduler.num_warmup_steps = config.trainer.global_num_warmup_steps |
|
|
|
if getattr(config.trainer, "global_num_training_steps", None) is not None: |
|
rprint(f"Global num_training_steps was: {config.lr_scheduler.num_training_steps}. Applying to num_training_steps") |
|
config.lr_scheduler.num_training_steps = config.trainer.global_num_training_steps |
|
|
|
if not config.trainer.disable_adjust_num_warmup_steps: |
|
rprint(f"Original num_warmup_steps was: {config.lr_scheduler.num_warmup_steps}") |
|
config.lr_scheduler.num_warmup_steps = config.lr_scheduler.num_warmup_steps * num_processes |
|
rprint(f"Setting num_warmup_steps to: {config.lr_scheduler.num_warmup_steps}") |
|
|
|
if hasattr(config.lr_scheduler, "num_training_steps"): |
|
rprint(f"Original num_training_steps was: {config.lr_scheduler.num_training_steps}") |
|
config.lr_scheduler.num_training_steps = config.lr_scheduler.num_training_steps * num_processes |
|
rprint(f"Setting num_training_steps to: {config.lr_scheduler.num_training_steps}") |
|
|
|
assert config.trainer.allow_dynamic_nodes or (os.environ.get("XLA_USE_SPMD", "0") == "1") or accelerator.num_processes == ( |
|
config.trainer.devices * num_nodes |
|
), f"Expected {config.trainer.devices * num_nodes} GPUs but got {accelerator.num_processes} processes." |
|
|
|
    compute_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        compute_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        compute_dtype = torch.bfloat16

    if compute_dtype != torch.bfloat16:
        rprint(f"WARNING!!!! Compute dtype is: {compute_dtype}")
    else:
        rprint(f"Compute dtype is: {compute_dtype}")
|
|
|
if is_main_process(): |
|
instantiate_wandb(config, accelerator) |
|
|
|
run_cmd = get_run_cmd(config) |
|
with open_dict(config): |
|
config.trainer.devices = accelerator.num_processes |
|
        config.trainer.dtype = str(compute_dtype)
|
if hasattr(config, "state"): |
|
config.state.cmd = run_cmd |
|
else: |
|
config.state = OmegaConf.create(dict(cmd=run_cmd)) |
|
|
|
OmegaConf.set_readonly(config, True) |
|
|
|
if getattr(config.trainer, "attach_oom_observer", False): |
|
from torchtnt.utils.oom import attach_oom_observer |
|
attach_oom_observer(output_dir=str(os.getcwd()), trace_max_entries=500000) |
|
rprint(f"Attached OOM observer to {os.getcwd()}") |
|
train_ds, valid_ds = dataloader.get_dataloaders(config, tokenizer, device=accelerator.device, skip_train=(config.mode == 'eval' and not config.eval.val_with_train_data)) |
|
model = Diffusion(config=config, tokenizer=valid_ds.tokenizer, device=accelerator.device) |
|
|
|
if is_main_process(): |
|
print_params(model.backbone) |
|
|
|
try: |
|
if getattr(config.model, "image_model", False) is False: |
|
_print_batch(train_ds, valid_ds, tokenizer) |
|
    except Exception:
        # Printing a sample batch is purely diagnostic; never let it interrupt startup.
        pass
|
|
|
get_ema_path = lambda x: Path(x) / "ema.ckpt" |
|
SAMPLER_NAME = "weighted_dataset_sampler" |
|
|
|
def save_model_hook(models, weights, output_dir): |
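        """Accelerate save_state pre-hook: persist EMA weights, the run config, and the weighted-sampler state next to the model weights."""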
|
nonlocal model, accelerator, train_ds |
|
|
|
if is_main_process(): |
|
with try_except(write_error_to_file=True): |
|
if getattr(model, "ema", None) is not None: |
|
torch.save(accelerator.unwrap_model(model).ema.state_dict(), get_ema_path(output_dir)) |
|
rprint(f"Saved EMA to {get_ema_path(output_dir)}") |
|
|
|
save_config_to_ckpt(config, output_dir, model) |
|
|
|
with try_except(write_error_to_file=True): |
|
if config.data.use_weighted_tensordict_sampler: |
|
from accelerate.utils import save |
|
output_sampler_file = output_dir.joinpath(f"{SAMPLER_NAME}_train.bin") |
|
save(train_ds.sampler.state_dict(), output_sampler_file, save_on_each_node=False, safe_serialization=False) |
|
rprint(f"Sampler state for dataloader saved in {output_sampler_file}") |
|
|
|
initial_global_step = None |
|
def load_model_hook(models, input_dir): |
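        """Accelerate load_state pre-hook: restore the global step, EMA weights, and (optionally) the weighted-sampler state from ``input_dir``."""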
|
nonlocal initial_global_step, model, train_ds |
|
config_path = os.path.join(input_dir, "config.yaml") |
|
ckpt_config = OmegaConf.load(config_path) |
|
initial_global_step = ckpt_config.state.ckpt_step |
|
model.global_step = initial_global_step |
|
try: |
|
            if hasattr(ckpt_config.state, "num_evals"):
                model.num_evals = ckpt_config.state.num_evals
|
except Exception as e: |
|
rprint(f"Error loading model: {e}") |
|
rprint(f"Loaded global step {initial_global_step}") |
|
|
|
state_dict = None |
|
if getattr(config.checkpointing, "load_from_old_attention_format", False): |
|
state_dict = load_file(os.path.join(input_dir, "model.safetensors")) |
|
state_dict = convert_state_dict_keys(state_dict) |
|
|
|
if getattr(model, "ema", None) is not None: |
|
if get_ema_path(input_dir).exists(): |
|
rprint(f"Loading EMA from {get_ema_path(input_dir)}") |
|
model.ema.load_state_dict(torch.load(get_ema_path(input_dir), map_location='cpu')) |
|
else: |
|
rprint(f"No EMA found, initializing EMA with state_dict") |
|
if state_dict is None: |
|
state_dict = load_file(os.path.join(input_dir, "model.safetensors")) |
|
|
|
|
|
accelerator.unwrap_model(models[0]).load_state_dict(state_dict) |
|
from models.ema import EMAModel |
|
model.ema = EMAModel(accelerator.unwrap_model(models[0]).parameters(), decay=config.trainer.ema) |
|
|
|
if config.data.use_weighted_tensordict_sampler and not is_torch_xla_available(): |
|
input_sampler_file = Path(input_dir).joinpath(f"{SAMPLER_NAME}_train.bin") |
|
if train_ds is not None and input_sampler_file.exists(): |
|
train_ds.sampler.load_state_dict(torch.load(input_sampler_file)) |
|
rprint("All dataloader sampler states loaded successfully") |
|
|
|
accelerator.register_save_state_pre_hook(save_model_hook) |
|
accelerator.register_load_state_pre_hook(load_model_hook) |
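    # These pre-hooks run inside accelerator.save_state()/load_state(), so EMA weights, the sampler
    # state, and the resolved config are written and restored alongside the model weights.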
|
model.init_dataloader(train_ds, valid_ds) |
|
model.set_accelerator(accelerator, ckpt_path) |
|
model.set_callbacks() |
|
|
|
if getattr(config.checkpointing, "load_from_text_model", None) is not None: |
|
rprint(f"Loading from text model") |
|
model.custom_load_checkpoint() |
|
|
|
if getattr(config.checkpointing, "load_from_lightning_ckpt", None) is not None: |
|
ckpt = torch.load(config.checkpointing.load_from_lightning_ckpt) |
|
initial_global_step = ckpt["global_step"] |
|
state_dict_ = {k.removeprefix("backbone."): v for k, v in ckpt["state_dict"].items() if "backbone" in k} |
|
state_dict_ = {k.replace(".attn_", ".attention.attn_"): v for k, v in state_dict_.items()} |
|
accelerator.unwrap_model(model.backbone).load_state_dict(state_dict_) |
|
|
|
if config.trainer.ema > 0: |
|
model.ema.load_state_dict(ckpt["ema"]) |
|
|
|
rprint(f"Loaded lightning ckpt: {config.checkpointing.load_from_lightning_ckpt}") |
|
|
|
if initial_global_step is not None: |
|
|
|
model.global_step = initial_global_step |
|
rprint(f"Set global step to {initial_global_step}") |
|
|
|
contexts = [] |
|
if config.trainer.nvtx_profile: |
|
contexts.append(torch.autograd.profiler.emit_nvtx(record_shapes=True)) |
|
|
|
if config.trainer.profile_memory: |
|
contexts.append(profile_memory()) |
|
|
|
using_torch_elastic = torch.distributed.is_torchelastic_launched() |
|
if using_torch_elastic: |
|
rprint(f"Torchelastic launched: {torch.distributed.is_torchelastic_launched()}") |
|
contexts.append(ErrorHandler()) |
|
|
|
with ExitStack() as stack: |
|
for ctx in contexts: |
|
stack.enter_context(ctx) |
|
|
|
rprint(f"output_dir: {config.output_dir}") |
|
model.to(accelerator.device) |
|
if config.mode == 'train': |
|
model.train() |
|
        elif config.mode == 'eval':
            model.validate(None)
|
elif config.mode == 'zero-shot-eval': |
|
model.zero_shot_eval() |
|
else: |
|
raise ValueError(f"Invalid mode: {config.mode}") |
|
|
|
accelerator.end_training() |
|
|
|
|
|
def get_run_cmd(config): |
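    """Reconstruct a shell command that reproduces this run (HYDRA_RUN_DIR plus an accelerate launch or sbatch invocation)."""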
|
orig_argv = deepcopy(sys.argv) |
|
|
|
prepend_argv = [] |
|
if "HYDRA_RUN_DIR" in os.environ: |
|
prepend_argv.append(f"HYDRA_RUN_DIR='{os.environ['HYDRA_RUN_DIR']}'") |
|
else: |
|
prepend_argv.append(f"HYDRA_RUN_DIR='{str(Path(config.output_dir).resolve())}'") |
|
|
|
    if len(orig_argv) > 1 and orig_argv[1].startswith("experiments=["):
|
orig_argv[1] = orig_argv[1].removeprefix("experiments=[").removesuffix("]") |
|
orig_argv[1] = f"experiments=\'[{orig_argv[1]}]\'" |
|
|
|
if os.environ.get("CUSTOM_SBATCH_LAUNCHER", "0") == "1": |
|
sbatch_script_path = 'scripts/slurm.sh' |
|
orig_argv.pop(0) |
|
orig_argv = ["sbatch", f"--nodes={os.environ.get('SLURM_NNODES', '1')}", f"--gpus-per-node={os.environ.get('SLURM_GPUS_PER_NODE', '1')}", f"--partition={os.environ.get('SLURM_JOB_PARTITION', 'all')}", sbatch_script_path] + orig_argv |
|
else: |
|
prepend_argv.append("accelerate launch") |
|
|
|
full_run_cmd = " ".join(prepend_argv + orig_argv) |
|
rprint(f"Full run cmd: {full_run_cmd}") |
|
return full_run_cmd |
|
|
|
@hydra.main(version_base=None, config_path=CONFIG_PATH, config_name="config") |
|
@try_except() |
|
def main(config):
    """Main entry point for trainer."""
    if is_sweep():
        print(f"Checking if we need to requeue for job id {os.environ.get('SLURM_JOB_ID', 'N/A')}")
        from unidisc.utils.slurm_requeue import check_requeue
        check_requeue()
        print(f"Done checking if we need to requeue for job id {os.environ.get('SLURM_JOB_ID', 'N/A')}.")
|
import torch |
|
if is_torch_xla_available(): |
|
builtins.HAS_XLA_SPAWNED = True |
|
os.environ['PJRT_DEVICE'] = 'TPU' |
|
|
|
if config.trainer.precision == "bf16": |
|
os.environ['XLA_USE_BF16'] = '1' |
|
|
|
if config.devices == 1 and config.trainer.xla_spmd is False and config.trainer.fsdp is False: |
|
os.environ['TPU_PROCESS_BOUNDS'] = '1,1,1' |
|
os.environ['TPU_VISIBLE_CHIPS'] = '0' |
|
gprint(f"Setting TPU_PROCESS_BOUNDS: {os.environ['TPU_PROCESS_BOUNDS']}") |
|
gprint(f"Setting TPU_VISIBLE_CHIPS: {os.environ['TPU_VISIBLE_CHIPS']}") |
|
|
|
if config.trainer.tpu_eager: |
|
os.environ['XLA_USE_EAGER_DEBUG_MODE'] = '1' |
|
|
|
if config.trainer.tpu_compile_debug: |
|
os.environ['PT_XLA_DEBUG'] = '1' |
|
os.environ['PT_XLA_DEBUG_LEVEL'] = '2' |
|
os.environ['XLA_IR_DEBUG'] = '1' |
|
os.environ['XLA_HLO_DEBUG'] = '1' |
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' |
|
os.environ['TF_CPP_VMODULE'] = 'xla_graph_executor=5,pjrt_computation_client=3' |
|
|
|
|
|
    spmd_mesh, spmd_mesh_shape, axis_names, num_nodes = None, None, None, None
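    # When SPMD is enabled, the mesh defaults to (num_nodes, local_devices, 1); axis names are
    # ('dcn', 'fsdp', 'model') for a single slice and ('data', 'fsdp', 'tensor') when multislice is used.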
|
if config.trainer.xla_spmd: |
|
import torch_xla.core.xla_model as xm |
|
import torch_xla.distributed.spmd as xs |
|
import torch_xla.runtime as xr |
|
from accelerate import PartialState |
|
from torch_xla._internal import tpu |
|
auto_spmd = getattr(config.trainer, "auto_spmd", False) |
|
|
|
xr.use_spmd(auto=auto_spmd) |
|
force_global_devices = getattr(config.trainer, "force_global_devices", None) |
|
force_local_devices = getattr(config.trainer, "force_local_devices", None) |
|
assert (force_global_devices is None) == (force_local_devices is None), "Must set both or neither" |
|
|
|
if force_global_devices is not None: |
|
num_global_devices = force_global_devices |
|
num_local_devices = force_local_devices |
|
gprint(f"Using force global devices: num_global_devices={num_global_devices}, num_local_devices={num_local_devices}") |
|
else: |
|
num_global_devices = xr.global_runtime_device_count() |
|
num_local_devices = tpu.num_available_devices() |
|
assert num_global_devices == tpu.num_expected_global_devices() |
|
assert tpu.num_available_devices() == tpu.num_available_chips() == tpu.num_local_processes() |
|
|
|
num_nodes = num_global_devices // num_local_devices |
|
spmd_mesh_shape = getattr(config.trainer, "spmd_mesh", None) |
|
if spmd_mesh_shape is None: |
|
spmd_mesh_shape = (num_nodes, num_local_devices, 1) |
|
|
|
if getattr(config.trainer, "force_disable_replicas", False): |
|
spmd_mesh_shape = (1, num_global_devices, 1) |
|
rprint(f"Forcing disable replicas: {spmd_mesh_shape}") |
|
|
|
if auto_spmd is False: |
|
if getattr(config.trainer, "spmd_multislice", None) is not None: |
|
from torch_xla.distributed.spmd import HybridMesh |
|
ici_mesh_shape = spmd_mesh_shape |
|
dcn_mesh_shape = (config.trainer.spmd_multislice, 1, 1) |
|
spmd_mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('data','fsdp','tensor')) |
|
rprint(f"Using multislice: {config.trainer.spmd_multislice}: {ici_mesh_shape} {dcn_mesh_shape}, {spmd_mesh.shape()}") |
|
else: |
|
spmd_mesh = xs.Mesh(np.array(range(num_global_devices)), spmd_mesh_shape, ('dcn', 'fsdp', 'model')) |
|
xs.set_global_mesh(spmd_mesh) |
|
|
|
config.devices = 1 |
|
config.nodes = 1 |
|
|
|
with read_write(config): |
|
with open_dict(config): |
|
config.state = OmegaConf.create(dict(spmd_mesh=spmd_mesh_shape)) |
|
config.state.axis_names = axis_names |
|
config.state.num_nodes = num_nodes |
|
config.state.num_global_devices = num_global_devices |
|
config.state.num_local_devices = num_local_devices |
|
config.state.worker_ips = tpu.get_worker_ips() |
|
if os.environ.get("TPU_NAME") is not None: |
|
config.state.tpu_name = os.environ.get("TPU_NAME") |
|
|
|
if config.trainer.tpu_eager: |
|
import torch_xla |
|
torch_xla.experimental.eager_mode(True) |
|
|
|
if config.trainer.tpu_profile: |
|
if config.trainer.tpu_profile_markers: |
|
os.environ['XLA_IR_DEBUG'] = '1' |
|
os.environ['XLA_HLO_DEBUG'] = '1' |
|
import torch_xla.debug.profiler as xp |
|
server = xp.start_server(9012) |
|
|
|
if config.trainer.tpu_cache: |
|
import torch_xla.runtime as xr |
|
readonly = not is_main_process() |
|
rprint(f"Initializing TPU cache with readonly={readonly}") |
|
xr.initialize_cache(str((Path.home() / '.cache' / 'unidisc' / f"tpu_{get_rank()}_{get_hostname().replace('-', '_')}").resolve()), readonly=readonly) |
|
|
|
if config.trainer.enable_jax_smi: |
|
initialise_tracking() |
|
rprint("Initializing jax-smi") |
|
|
|
from unidisc.utils.logging_utils import set_logger |
|
set_logger(f"{__name__} {get_slurm_log_prefix()}", Path(f"{get_slurm_filename_info()}_{get_hostname().replace('-', '_')}.out")) |
|
|
|
    if is_torch_xla_available():
        # Import here (not only under xla_spmd) so this diagnostic block also works when SPMD is disabled.
        import torch_xla.core.xla_model as xm
        import torch_xla.runtime as xr
        from torch_xla._internal import tpu

        from accelerate import PartialState
|
gprint( |
|
f"Computed Rank: {get_rank()}, " |
|
f"Is Main Process: {is_main_process()}, " |
|
f"Is Local Main Process: {is_local_main_process()}, " |
|
f"XLA world size: {xr.world_size()}, " |
|
f"XLA Model Ordinal: {xm.get_ordinal()}, " |
|
f"XLA Global Ordinal: {xr.global_ordinal()}, " |
|
f"XLA Supported Devices: {xm.get_xla_supported_devices()}, " |
|
f"Accelerate Local Process Index: {PartialState().local_process_index}, " |
|
f"Task ID: {tpu.task_id()}, " |
|
f"Worker ID: {tpu.worker_id()} " |
|
f"global device count: {xr.global_runtime_device_count()}, " |
|
f"local process count: {xr.local_process_count()}, " |
|
f"local device count: {xr.local_device_count()}, " |
|
f"addressable device count: {xr.addressable_device_count()}, " |
|
f"num_expected_global_devices: {tpu.num_expected_global_devices()}, " |
|
f"num_available_devices: {tpu.num_available_devices()}, " |
|
f"num_available_chips: {tpu.num_available_chips()}, " |
|
f"num_local_processes: {tpu.num_local_processes()}, " |
|
f"process_bounds_size: {tpu.process_bounds_size()}, " |
|
f"get_worker_ips: {tpu.get_worker_ips()}, " |
|
f"Computed Num Nodes: {num_nodes}, " |
|
f"Specified Mesh: {spmd_mesh_shape}, " |
|
f"Specified Mesh Axes: {axis_names}" |
|
) |
|
|
|
gprint(f"LIBTPU_INIT_ARGS: {os.environ.get('LIBTPU_INIT_ARGS', 'None')}") |
|
gprint(f"XLA_FLAGS: {os.environ.get('XLA_FLAGS', 'None')}") |
|
|
|
if getattr(config.trainer, "disable_ddp_optimizer", False): |
|
torch._dynamo.config.optimize_ddp = False |
|
|
|
if config.seed is not None: |
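        # Offset the base seed per rank so each process draws different noise/data; eval uses a larger
        # stride, presumably to keep eval RNG streams disjoint from training ones.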
|
if config.mode == 'eval': |
|
config.seed = config.seed + 1000 * int(get_rank()) |
|
else: |
|
config.seed = config.seed + int(get_rank()) |
|
np.random.seed(config.seed) |
|
random.seed(config.seed) |
|
torch.manual_seed(config.seed) |
|
if is_torch_cuda_available(): |
|
|
|
torch.cuda.manual_seed_all(config.seed) |
|
|
|
if is_torch_xla_available(): |
|
import torch_xla.core.xla_model as xm |
|
xm.set_rng_state(config.seed) |
|
gprint(f"Set seed: {config.seed}") |
|
else: |
|
rprint("No seed provided") |
|
|
|
_print_config(config, resolve=True, save_cfg=True) |
|
|
|
with open(f"env_vars_{get_slurm_filename_info()}_{get_hostname().replace('-', '_')}.txt", "w") as f: |
|
for key, value in os.environ.items(): |
|
f.write(f"{key}={value}\n") |
|
|
|
tokenizer = dataloader.get_tokenizer(config) |
|
|
|
if "tokens" in config.data.train and (config.loader.num_workers > 0 or getattr(config.data, "force_mp_spawn", False)): |
|
from torch import multiprocessing as mp |
|
        try:
            mp.set_start_method("spawn")
            rprint(f"Start method set to: {mp.get_start_method()}")
        except RuntimeError:
            # set_start_method raises RuntimeError if the context was already initialized.
            rprint(f"Start method already set to: {mp.get_start_method()}")
|
|
|
rprint(f"Mode: {config.mode}") |
|
if config.mode == "sample_eval": |
|
generate_samples(config, tokenizer) |
|
else: |
|
try: |
|
run(config, tokenizer) |
|
except Exception as e: |
|
rprint(f"Traceback: {traceback.format_exc()}") |
|
rprint(f"Exception: {e}") |
|
|
|
timestamp = int(__import__("time").time_ns()) |
|
error_filepath = f"exception_{timestamp}_{process_file_prefix()}.out" |
|
with open(error_filepath, "w") as file: |
|
file.write(traceback.format_exc()) |
|
rprint(f"See error file {Path(error_filepath).resolve()} for traceback") |
|
|
|
if is_torch_xla_available(): |
|
                sys.exit(1)
|
|
|
if ("SLURM_JOB_ID" not in os.environ) and ("RESTART_FAULT_TOLERANT" not in os.environ) and not is_torch_xla_available(): |
|
gprint(f"Entering debugger") |
|
breakpoint(traceback=e.__traceback__) |
|
else: |
|
rprint(f"Not breaking, SLURM_JOB_ID: {os.environ.get('SLURM_JOB_ID')}, RESTART_FAULT_TOLERANT: {os.environ.get('RESTART_FAULT_TOLERANT')}") |
|
|
|
if "RESTART_FAULT_TOLERANT" in os.environ: |
|
sigterm_handler = signal.getsignal(signal.SIGTERM) |
|
if callable(sigterm_handler): |
|
rprint(f"Calling SIGTERM handler") |
|
sigterm_handler(signal.SIGTERM, None) |
|
|
|
try: |
|
if config.trainer.num_nodes > 1 and config.debug is False and is_main_process(): |
|
wandb.alert(title="Exception!", text=f"{e}, {traceback.format_exc()}") |
|
            except Exception:
                pass
|
raise e |
|
finally: |
|
pass |
|
|
|
if __name__ == "__main__": |
|
main() |
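
# Example launch (illustrative only; the script filename, experiment names, and paths below are assumptions,
# mirroring the command format reconstructed by get_run_cmd):
#   HYDRA_RUN_DIR='outputs/example_run' accelerate launch main.py mode=train experiments='[example_experiment]'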
|
|