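"""Main training / evaluation entry point for unidisc.

Builds the Hydra config, sets up Accelerate (DDP / FSDP / XLA SPMD), wandb
tracking, dataloaders, and the Diffusion model, then dispatches to the
requested mode: train, eval, zero-shot-eval, or sample_eval.
"""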
import os
import sys
from contextlib import ExitStack
from pathlib import Path
from constants import CONFIG_PATH, LIB_DIR
sys.path.append(str(LIB_DIR / "hydra_submitit_launcher"))
import builtins
import random
import re
import signal
import traceback
from copy import deepcopy
from datetime import datetime
import hydra
import numpy as np
import omegaconf
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf, open_dict, read_write
from safetensors.torch import load_file, save_file
import dataloader
from model import Diffusion
import utils
import wandb
from decoupled_utils import (check_gpu_memory_usage, get_hostname,
get_local_rank, get_rank, get_slurm_filename_info,
get_slurm_log_prefix, get_tpu_devices,
get_world_size, gprint, is_local_main_process,
is_main_process, is_torch_cuda_available,
is_torch_xla_available, print_params,
process_file_prefix, profile_memory, rank_zero_fn,
rprint, set_global_breakpoint, set_global_exists,
set_timing_builtins, try_except)
from utils import (ErrorHandler, _print_config, convert_state_dict_keys, set_omega_conf_resolvers, set_torch_defaults)
# Only needed when debugging hydra
# os.environ["HYDRA_FULL_ERROR"] = "1"
set_global_breakpoint() # Overrides breakpoint() to use ipdb.set_trace() instead and handle distributed training
set_global_exists()
set_omega_conf_resolvers()
if is_torch_xla_available():
from jax_smi import initialise_tracking
def _load_from_checkpoint(config, tokenizer):
OmegaConf.resolve(config)
if "hf" in config.backbone:
return Diffusion(config=config, tokenizer=tokenizer).to("cuda")
return Diffusion.load_from_checkpoint(config.eval.checkpoint_path, tokenizer=tokenizer, config=config)
@rank_zero_fn
def _print_batch(train_ds, valid_ds, tokenizer, k=64):
for dl_type, dl in [("train", train_ds), ("valid", valid_ds)]:
rprint(f"Printing {dl_type} dataloader batch.")
batch = next(iter(dl))
rprint("Batch input_ids.shape", batch["input_ids"].shape)
first = batch["input_ids"][0, :k]
last = batch["input_ids"][0, -k:]
rprint(f"First {k} tokens:", tokenizer.decode(first))
rprint("ids:", first)
rprint(f"Last {k} tokens:", tokenizer.decode(last))
rprint("ids:", last)
def generate_samples(config, tokenizer):
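    """Load the model from a checkpoint and generate text samples, optionally reporting generative perplexity."""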
rprint("Generating samples.")
model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
model.gen_ppl_metric.reset()
if config.eval.disable_ema:
rprint("Disabling EMA.")
model.ema = None
stride_length = config.sampling.stride_length
num_strides = config.sampling.num_strides
for _ in range(config.sampling.num_sample_batches):
if config.sampling.semi_ar:
_, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
stride_length=stride_length, num_strides=num_strides, dt=1 / config.sampling.steps
)
text_samples = intermediate_samples[-1]
# Note: Samples generated using semi-ar method
            # need to be processed before computing generative perplexity
# since these samples contain numerous <|endoftext|> tokens
# and diffusion.compute_generative_perplexity() discards
# any text after the first EOS token.
else:
samples = model.restore_model_and_sample(num_steps=config.sampling.steps)
text_samples = model.tokenizer.batch_decode(samples)
model.compute_generative_perplexity(text_samples)
rprint("Text samples:", text_samples)
if not config.sampling.semi_ar:
rprint("Generative perplexity:", model.gen_ppl_metric.compute())
return text_samples
def instantiate_wandb(config, accelerator):
if is_torch_xla_available():
gprint("Initializing wandb for XLA")
if config.mode == 'eval':
config.wandb.project = f"{config.wandb.project}-eval"
elif config.mode == 'zero-shot-eval':
config.wandb.project = f"{config.wandb.project}-zero-shot-eval"
if config.wandb.group is not None:
config.wandb.group = str(config.wandb.group)
# We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
wandb_kwargs = dict(config.wandb)
if getattr(config, "sweep_id", None) is not None:
rprint(f"Setting Wandb group to {config.sweep_id}")
wandb_kwargs["group"] = config.sweep_id
del wandb_kwargs["project"]
accelerator.init_trackers(
config.wandb.project, config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True), init_kwargs=dict(wandb=wandb_kwargs)
)
if getattr(config.trainer, "log_code", True) and is_main_process():
if "matrix" in get_hostname():
rprint(f"Not logging code to wandb on {get_hostname()}")
else:
rprint(f"Logging code to wandb from {Path(__file__).parent}")
try:
wandb.run.log_code(
root=str(Path(__file__).parent),
include_fn=lambda path: any(path.endswith(f) for f in (".py", ".yaml", ".yml", ".txt", ".md")),
exclude_fn=lambda path, root: any(x in os.path.relpath(path, root) for x in ("output", "multirun", "logs", "wandb")),
)
except Exception as e:
rprint(f"Failed to log code to wandb: {e}")
with open_dict(config):
try:
config.wandb_url = wandb.run.get_url()
wandb.define_metric("global_samples")
wandb.define_metric("effective_global_tokens")
wandb.define_metric("effective_global_step")
wandb.define_metric("train_metrics/samples")
wandb.define_metric("trainer/loss", step_metric="global_samples")
except Exception as e:
rprint(f"Failed to get wandb url: {e}")
def instantiate_model(config, tokenizer):
model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
if config.eval.disable_ema:
rprint("Disabling EMA.")
model.ema = None
return model
def gconf(config, attr):
return getattr(config, attr, None)
def has_ckpt(config, attr):
return gconf(config, attr) is not None and utils.fsspec_exists(gconf(config, attr))
def set_env_vars(config):
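    """Configure process-level environment as needed: RLIMIT_MEMLOCK, NCCL/SDPA flags, and anomaly detection."""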
import torch
hostname = __import__("socket").gethostname()
rprint(f"Starting Training on {hostname}")
# os.environ["TORCHINDUCTOR_CACHE_DIR"] = str((Path.home() / ".cache" / "torchinductor").resolve())
if not is_torch_xla_available():
try:
            # Raises the soft RLIMIT_MEMLOCK to the hard limit (similar to `ulimit -l`) for this process [and children]
# This caused a significant amount of pain to figure out
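            # (typically needed so NCCL / RDMA can register pinned host memory without hitting the locked-memory limit)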
import resource
soft, hard = resource.getrlimit(resource.RLIMIT_MEMLOCK)
resource.setrlimit(resource.RLIMIT_MEMLOCK, (hard, hard))
if is_local_main_process():
gprint(f"Successfully set RLIMIT_MEMLOCK to {hard}")
except ValueError as e:
rprint(f"Failed to set RLIMIT_MEMLOCK: {e}")
except resource.error as e:
rprint(f"Error setting RLIMIT_MEMLOCK: {e}")
else:
rprint(f"Not setting RLIMIT_MEMLOCK on XLA")
if "matrix-3-28" in hostname or "matrix-3-26" in hostname:
rprint(f"Disabling NCCL P2P")
os.environ["NCCL_P2P_DISABLE"] = "1"
if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") != "":
assert False, f"TORCH_DISTRIBUTED_DEBUG is set to: {os.environ.get('TORCH_DISTRIBUTED_DEBUG')}. Please unset it as it starts a gloo backend."
if config.model.use_spda_attn:
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
os.environ["TORCH_CUDNN_MHA_ENABLED"] = "1"
        rprint("Setting SDPA flags")
if config.trainer.detect_anomaly:
torch.autograd.set_detect_anomaly(True)
def update_config_before_resolution(config):
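    """Adjust the config in place before resolution: output dirs, per-GPU batch-size heuristics, sampling steps, checkpoint directories/symlinks, and eval-related overrides."""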
import torch
if hasattr(config, "training"):
rprint(f"'training' has been refactored to 'trainer'. Please update the config.")
with open_dict(config):
config.output_dir = os.getcwd()
config.logging_dir = os.getcwd()
if config.model.use_kv_cache is False and config.mode == "eval" and config.loader.eval_batch_size > 1:
config.loader.eval_batch_size = max(config.loader.eval_batch_size, 16)
# todo revert?
if getattr(config.eval, 'txt_img_ratio', None) is not None:
# 2,1,0.5,0.25
tot = config.model.length
        # if it's 2:1, then distribute the tokens as 2/3, 1/3
        # if it's 1:1, then distribute the tokens as 1/2, 1/2
        # if it's 0.5:1, then distribute the tokens as 2/3, 1/3
        # if it's 0.25:1, then distribute the tokens as 1/4, 3/4
if config.eval.txt_img_ratio == 2:
# do first 2/3 tokens as text, last 1/3 as image
config.model.txt_length = int(tot * 2/3)
elif config.eval.txt_img_ratio == 1:
config.model.txt_length = int(tot / 2)
elif config.eval.txt_img_ratio == 0.5:
config.model.txt_length = int(tot * 2/3)
elif config.eval.txt_img_ratio == 0.25:
config.model.txt_length = int(tot / 4)
config.model.img_length = tot - config.model.txt_length
config.model.length = config.model.txt_length + config.model.img_length
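        # Worked example (illustrative): with model.length=1024 and txt_img_ratio=2, txt_length=682 and img_length=342.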
# config.eval.attention_caching_txt_to_img_ratio = config.model.txt_length // 20
if getattr(config.eval, "varying_seq_len_ratio", False):
assert getattr(config.eval, "sampling_step_ratio", None) is not None, "Must set both varying_seq_len_ratio and sampling_step_ratio"
config.sampling.steps = int(config.model.length * config.eval.sampling_step_ratio)
if getattr(config.eval, "ablation_config", False):
if config.parameterization == "ar":
rprint(f"WARNING!!!!! FORCING AR PARAMS")
config.trainer.ar_shift = True
config.model.full_attention = False
config.data.keep_tensordict_on_disk = True
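    # Heuristically scale batch sizes down on lower-memory GPUs; the divisors below are empirical.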
if is_torch_cuda_available():
if any(x.lower() in torch.cuda.get_device_name().lower() for x in ["v100", "1080", "2080", "quadro", "titan"]) or torch.cuda.get_device_capability()[0] <= 7:
            rprint(f"Detected an older GPU ({torch.cuda.get_device_name()}), setting precision to fp32")
config.trainer.precision = "no"
config.model.force_optimized_native_attn = False
config.trainer.compile = False
if any(x.lower() in torch.cuda.get_device_name().lower() for x in ["2080", "quadro"]):
config.loader.eval_batch_size = config.loader.eval_batch_size // 7
config.loader.batch_size = config.loader.batch_size // 7
elif any(x.lower() in torch.cuda.get_device_name().lower() for x in ["1080", "titan"]):
config.loader.eval_batch_size = config.loader.eval_batch_size // 6
config.loader.batch_size = config.loader.batch_size // 6
else:
config.loader.eval_batch_size = config.loader.eval_batch_size // 2
config.loader.batch_size = config.loader.batch_size // 2
elif "a5000" in torch.cuda.get_device_name().lower() or "a4500" in torch.cuda.get_device_name().lower():
config.loader.eval_batch_size = config.loader.eval_batch_size // 2
config.loader.batch_size = config.loader.batch_size // 2
else:
rprint(f"Found {torch.cuda.get_device_name()}")
config.loader.eval_batch_size = config.loader.eval_batch_size // 2
config.loader.batch_size = config.loader.batch_size // 2
    if getattr(config, "parameterization", None) == "ar" and config.eval.cfg is not None:
config.loader.eval_batch_size = config.loader.eval_batch_size // 2
config.loader.batch_size = config.loader.batch_size // 2
config.loader.eval_batch_size = max(config.loader.eval_batch_size, 1)
config.loader.batch_size = max(config.loader.batch_size, 1)
    if getattr(config, "parameterization", None) == "ar":
config.trainer.compile = False
if getattr(config.sampling, "sampling_step_frac", None) is not None:
config.sampling.steps = int(config.model.length * config.sampling.sampling_step_frac)
rprint(f"Setting sampling steps to {config.sampling.steps}")
if os.environ.get("SUBMITIT_FOLDER") is not None or os.environ.get("CUSTOM_SBATCH_LAUNCHER", "0") == "1":
rprint(f'Using submitit folder: {os.environ.get("SUBMITIT_FOLDER", "")}, setting slurm=True')
config.slurm = True
if (config.debug is False or os.environ.get("HYDRA_RUN_DIR_NAME", None) is not None) and torch.distributed.is_torchelastic_launched():
config.trainer.restart_on_failure = True
rprint(f"Setting restart_on_failure to True")
if config.trainer.restart_on_failure and config.mode == 'train':
if os.environ.get("HYDRA_RUN_DIR", None) is None and os.environ.get("HYDRA_RUN_DIR_NAME", None) is None:
os.environ["HYDRA_RUN_DIR"] = config.output_dir
rprint(f"Setting HYDRA_RUN_DIR to {os.environ['HYDRA_RUN_DIR']}")
else:
rprint(f"Not setting HYDRA_RUN_DIR, already set to {os.environ.get('HYDRA_RUN_DIR', 'N/A')}, and HYDRA_RUN_DIR_NAME is set to {os.environ.get('HYDRA_RUN_DIR_NAME', 'N/A')}")
os.environ["RESTART_FAULT_TOLERANT"] = "1"
rprint(f"Setting RESTART_FAULT_TOLERANT to 1")
elif config.trainer.restart_on_failure:
rprint(f"Restart_on_failure is True, but mode is not 'train', so not setting restart fault tolerant")
relevant_vars = {}
for key, value in os.environ.items():
if "SLURM" in key or "NCCL" in key or "TORCH" in key:
relevant_vars[key] = value
config.env_vars = relevant_vars
if config.trainer.profile_memory:
config.trainer.max_steps = 2
if config.debug and config.trainer.force_enable_checkpointing is False and (config.trainer.ckpt_steps is None or config.trainer.ckpt_steps > 0):
config.trainer.ckpt_steps = 10000
rprint(f"Only checkpointing every {config.trainer.ckpt_steps} steps in debug mode")
if config.loader.global_batch_size is None:
config.loader.global_batch_size = config.loader.batch_size * config.trainer.accumulate_grad_batches * (1 if is_torch_xla_available() else get_world_size())
config.loader.eval_global_batch_size = config.loader.global_batch_size
if config.trainer.scale_lr_by_batch_size:
config.optim.lr = config.optim.lr * (config.loader.global_batch_size / 512)
rprint(f"Setting global batch size to {config.loader.global_batch_size}, lr to {config.optim.lr}")
if config.mode != 'train':
config.checkpointing.resume_wandb = False
config.wandb.resume = None
if config.trainer.use_spmd_distributed_checkpointing is None:
config.trainer.use_spmd_distributed_checkpointing = is_torch_xla_available() and config.trainer.xla_spmd
if config.trainer.disable_all_eval_generation:
config.eval.num_masking_viz_batches=0
config.eval.num_uncond_sample_batches=0
config.eval.num_sample_batches=0
config.eval.num_random_masking=0
config.eval.generate_samples=False
config.trainer.log_flops=False
config.eval.log_every_n_evals=-1
config.eval.log_every_n_fid = -1
config.model.image_model_fid_eval = False
rprint("Disabling all eval generation!!!")
if os.environ.get("XLA_IR_DEBUG", "0") == "1":
config.trainer.tpu_profile = True
if config.checkpointing_root_dir is not None:
assert "checkpoints" in config.checkpointing.save_dir
relative_path = Path(*Path(config.checkpointing.save_dir).relative_to(config.root_output_dir).parts[1:])
full_checkpointing_dir = Path(config.checkpointing_root_dir) / relative_path
if config.checkpointing_root_dir is not None:
old_save_dir = Path(config.output_dir) / "checkpoints"
full_checkpointing_dir.mkdir(parents=True, exist_ok=True)
try:
if old_save_dir.exists():
rprint(f"WARNING: Cannot create symlink from {old_save_dir} to {full_checkpointing_dir} because {old_save_dir} exists.")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_save_dir = Path(*old_save_dir.parts[:-1]) / f"checkpoints_{timestamp}"
old_save_dir.symlink_to(full_checkpointing_dir, target_is_directory=True)
rprint(f"Created softlink from {old_save_dir} to {full_checkpointing_dir}")
# Create a symlink from the parent of full_checkpointing_dir named "original" back to config.output_dir
original_link = full_checkpointing_dir.parent / "original_output_dir"
if not original_link.exists():
original_link.symlink_to(Path(config.output_dir).resolve(), target_is_directory=True)
rprint(f"Created softlink from {original_link} to {config.output_dir}")
else:
rprint(f"WARNING: Symlink {original_link} already exists. Skipping creation.")
except OSError as e:
rprint(f"Error creating softlinks: {e}")
assert getattr(config.data, "allow_label", False) == getattr(config.trainer, "add_label", False) == (getattr(config.model, "add_labels", None) is not None) == getattr(config.eval, "class_conditional_fid", False), f"Mismatching values: data.allow_label={config.data.allow_label}, trainer.add_label={config.trainer.add_label}, model.add_labels={config.model.add_labels}, eval.class_conditional_fid={config.eval.class_conditional_fid}"
if getattr(config.loader, "num_eval_workers", None) is not None and config.loader.num_workers == 0:
rprint(f"Setting num_eval_workers to 0 because num_workers is 0")
config.loader.num_eval_workers = 0
if config.trainer.disable_all_checkpointing:
gprint("-"*50)
gprint(f"WARNING: DISABLING ALL CHECKPOINTING!!!!")
gprint("-"*50)
gprint(f"WARNING: DISABLING ALL CHECKPOINTING!!!!")
gprint("-"*50)
config.trainer.ckpt_steps = 100000000
if config.sampling.steps != config.sampling.max_sampling_steps:
rprint(f"WARNING!!!! steps {config.sampling.steps} != max_sampling_steps {config.sampling.max_sampling_steps}")
config.sampling.max_sampling_steps = config.sampling.steps
def get_latest_ckpt(config, input_dir):
if input_dir is None or not Path(input_dir).exists():
rprint(f"Project dir {input_dir} does not exist")
return None
if config.trainer.xla_spmd and is_torch_xla_available():
rprint(f"XLA SPMD detected, using XLA checkpointing")
if any(Path(input_dir).iterdir()):
rprint(f"Found existing files/folders in {input_dir}")
return input_dir
else:
rprint(f"No folders found in {input_dir}")
return None
folders = [str(folder) for folder in Path(input_dir).iterdir() if folder.is_dir() and ((folder / "model.safetensors").exists() or (folder / "config.yaml").exists())]
if len(folders) == 0:
rprint(f"No folders found in {input_dir}")
return None
def _inner(folder):
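        # Sort key: the trailing integer in the folder name, e.g. ".../checkpoints/160000" -> 160000.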
return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]
folders.sort(key=_inner)
rprint(f"Found folders: {folders}")
input_dir = folders[-1]
return input_dir
def is_sweep():
try:
subdir = HydraConfig.get().sweep.subdir
rprint(f"Found sweep subdir: {subdir}")
return True
except omegaconf.errors.InterpolationToMissingValueError:
return False
def get_sweep_run_name(config):
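    """Build a run-name suffix from Hydra sweep overrides and forced keys (or from the SLURM job id when not in a sweep)."""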
try:
subdir = HydraConfig.get().sweep.subdir
sweep_str = f"{subdir}_"
is_sweep = True
except omegaconf.errors.InterpolationToMissingValueError:
is_sweep = False
sweep_str = f"{os.environ.get('SLURM_JOB_ID', '')}_"
if getattr(config, "training", None) is not None and getattr(getattr(config, "training", None), "force_keys", None) is not None:
rprint("Using legacy keys")
forced_keys = set(config.training.force_keys)
else:
forced_keys = set(getattr(config.trainer, "forced_keys", []))
if is_sweep:
print(
f"Getting sweep keys: {HydraConfig.get().job.sweep_keys}, Tasks: {HydraConfig.get().overrides.task}, {getattr(config.trainer, 'forced_keys', [])}"
)
valid_keys = set(HydraConfig.get().job.sweep_keys)
for task in HydraConfig.get().overrides.task:
if task.removeprefix("+").split("=")[0] in valid_keys or task.removeprefix("+").split("=")[0] in forced_keys:
sweep_str += f"{task.removeprefix('+').split('=')[0].split('.')[-1]}={task.removeprefix('+').split('=')[1]}__"
if task.removeprefix("+").split("=")[0] in forced_keys:
forced_keys.remove(task.removeprefix("+").split("=")[0])
print(f"Forced key: {task.removeprefix('+').split('=')[0]}={task.removeprefix('+').split('=')[1]}")
for key in sorted(list(forced_keys)):
sweep_str += f"{key.split('.')[-1]}={OmegaConf.select(config, key)}__"
rprint(f"Sweep: {is_sweep=}, {sweep_str=}")
return "" if sweep_str == "" else sweep_str[:-2]
def save_config_to_ckpt(config, output_dir, model):
with try_except(write_error_to_file=True, clear_cuda_cache=True):
with read_write(config):
with open_dict(config):
config.state.ckpt_step = model.global_step
config.state.num_evals = model.num_evals
OmegaConf.save(config=config, f=Path(output_dir) / "config.yaml")
rprint(f"Saved global step {model.global_step}")
def determine_ckpt(config):
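    """Decide which checkpoint (if any) to resume from and configure the wandb run id / resume behavior accordingly."""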
has_recent_ckpt = False
rprint(f"Looking at checkpoint path: {getattr(config.checkpointing, 'resume_ckpt_path', None)}")
if (
config.checkpointing.resume_from_ckpt
and (latest_ckpt := get_latest_ckpt(config, getattr(config.checkpointing, "resume_ckpt_path", None))) is not None
and (Path(latest_ckpt) / "config.yaml").exists()
):
ckpt_path = latest_ckpt
has_recent_ckpt = True
if config.slurm:
config.wandb.resume = "must"
rprint(f"Resuming from checkpoint {ckpt_path}")
elif config.checkpointing.resume_from_ckpt and getattr(config.checkpointing, "initial_resume_ckpt_path", None) is not None:
ckpt_path = config.checkpointing.initial_resume_ckpt_path
rprint(f"Resuming from initial checkpoint {ckpt_path}")
else:
ckpt_path = None
if ckpt_path is not None and (config.checkpointing.resume_wandb or has_recent_ckpt):
loaded = OmegaConf.load(Path(ckpt_path) / "config.yaml")
if loaded.wandb.id is not None:
config.wandb.id = str(loaded.wandb.id)
rprint(f"Found wandb id: {config.wandb.id}")
else:
rprint(f"No wandb id found in checkpoint {ckpt_path}")
if config.checkpointing.resume_wandb and config.wandb.id is not None:
config.wandb.resume = "must"
rprint(f"Resuming wandb, setting must, run id: {config.wandb.id}")
elif config.slurm and config.wandb.id is None:
if os.environ.get("SLURM_ARRAY_TASK_COUNT", "") != "" and int(os.environ.get("SLURM_ARRAY_TASK_COUNT", "")) > 1:
config.wandb.id = str(os.environ.get("SLURM_ARRAY_JOB_ID")) + f"_{os.environ.get('SLURM_ARRAY_TASK_ID')}"
else:
config.wandb.id = str(os.environ.get("SLURM_JOB_ID"))
rprint(f"Setting wandb id to {config.wandb.id}")
if config.checkpointing.initial_resume_ckpt_path is not None and config.checkpointing.resume_wandb:
assert config.wandb.id is not None
if config.ckpt is not None:
ckpt_path = config.ckpt
rprint(f"Running eval with checkpoint {ckpt_path}")
if config.wandb.id is not None:
config.wandb.id = str(config.wandb.id)
if config.wandb.id is None or getattr(config.trainer, "force_new_wandb_id", False):
config.wandb.id = wandb.util.generate_id()
config.wandb.resume = "allow"
rprint(f"Set wandb id: {config.wandb.id}")
rprint(f"Using wandb id: {config.wandb.id}")
subdir = get_sweep_run_name(config)
rprint(f"Wandb name: {config.wandb.name}, Wandb subdir: {subdir}")
if config.wandb.name == 'default':
config.wandb.name = None
else:
config.wandb.name = (
(f"{config.wandb.name}_" if config.wandb.name else "")
+ (f"{subdir}_" if (subdir is not None and subdir != "") else "")
+ f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
)
if getattr(config.wandb, "group", None) is None and subdir is not None and config.debug and os.environ.get("SLURM_ARRAY_JOB_ID", "") != "":
config.wandb.group = os.environ.get("SLURM_ARRAY_JOB_ID")
rprint(f"Wandb group: {config.wandb.group}")
return ckpt_path
def run(config, tokenizer):
import torch
from accelerate import (Accelerator, DataLoaderConfiguration,
DDPCommunicationHookType,
DistributedDataParallelKwargs,
FullyShardedDataParallelPlugin)
from accelerate.state import AcceleratorState
from accelerate.utils import GradientAccumulationPlugin, ProjectConfiguration
set_torch_defaults(config.trainer.benchmark)
set_env_vars(config)
update_config_before_resolution(config)
ckpt_path = determine_ckpt(config)
OmegaConf.resolve(config)
if is_torch_cuda_available():
check_gpu_memory_usage()
if is_torch_cuda_available():
rprint(f"pt={torch.__version__}, cuda={torch.version.cuda}, nccl={torch.cuda.nccl.version()}")
rprint(f"GPU={torch.cuda.get_device_name()}, device compute capabilities={torch.cuda.get_device_capability()}, pytorch compute capabilities={torch.cuda.get_arch_list()}")
elif is_torch_xla_available():
rprint(f"XLA Devices={get_tpu_devices()}")
rprint(
f"Initial GROUP_RANK: {os.environ.get('GROUP_RANK', 'N/A')}, RANK: {os.environ.get('RANK', 'N/A')}, LOCAL_RANK: {os.environ.get('LOCAL_RANK', 'N/A')}, WORLD_SIZE: {os.environ.get('WORLD_SIZE', 'N/A')}, MASTER_ADDR: {os.environ.get('MASTER_ADDR', 'N/A')}, MASTER_PORT: {os.environ.get('MASTER_PORT', 'N/A')}, TORCHELASTIC_RUN_ID: {os.environ.get('TORCHELASTIC_RUN_ID', 'N/A')}, TORCHELASTIC_RESTART_COUNT: {os.environ.get('TORCHELASTIC_RESTART_COUNT', 'N/A')}, TORCHELASTIC_MAX_RESTARTS: {os.environ.get('TORCHELASTIC_MAX_RESTARTS', 'N/A')}, LOCAL_WORLD_SIZE: {os.environ.get('LOCAL_WORLD_SIZE', 'N/A')}, Elastic: {torch.distributed.is_torchelastic_launched()}"
)
rprint(f"Computed Rank: {get_rank()}, Local Rank: {get_local_rank()}, World Size: {get_world_size()}")
# This lets us have start_timing and end_timing functions and a global enable/disable
# We always use torch.cuda.synchronize before/after as otherwise the timing is not very meaningful
sync_timing = (config.trainer.nvtx_profile and getattr(config.trainer, "sync_nvtx_timing", True)) or getattr(config.trainer, "sync_timing", False)
set_timing_builtins(enable=config.trainer.nvtx_profile, sync=sync_timing)
num_nodes = config.trainer.num_nodes
with open_dict(config):
config.trainer = OmegaConf.merge(config.trainer, dict(mixed_precision=config.trainer.precision, log_with="wandb", log_gradients=None))
if getattr(config.trainer, "process_dataloader_only", False):
gprint("Processing dataloader only")
train_ds, valid_ds = dataloader.get_dataloaders(config, tokenizer, device="cpu", skip_train=(config.mode == 'eval' and not config.eval.val_with_train_data))
gprint(f"Exiting after processing dataloader")
return
accelerator_project_config = ProjectConfiguration(
project_dir=config.output_dir,
logging_dir=config.logging_dir,
automatic_checkpoint_naming=config.checkpointing.use_automatic_naming,
save_on_each_node=False,
)
accelerate_kwargs = dict()
gradient_kwargs = dict()
if config.trainer.fsdp and not (config.trainer.xla_spmd and is_torch_xla_available()):
rprint("Using FSDP...")
if config.backbone == "llama" or config.backbone == "gemma":
os.environ["ACCELERATE_USE_FSDP"] = "true"
os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
            os.environ["FSDP_BACKWARD_PREFETCH"] = "NO_PREFETCH"  # Saves memory
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
os.environ["FSDP_FORWARD_PREFETCH"] = "false"
os.environ["FSDP_OFFLOAD_PARAMS"] = "false"
os.environ["FSDP_SHARDING_STRATEGY"] = "FULL_SHARD"
os.environ["FSDP_STATE_DICT_TYPE"] = "SHARDED_STATE_DICT"
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
fsdp_plugin = FullyShardedDataParallelPlugin()
else:
os.environ["ACCELERATE_USE_FSDP"] = "true"
os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP" # or "SIZE_BASED_WRAP"
if config.backbone == "elm":
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "OpenELMDecoderLayer"
os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE"
os.environ["FSDP_SHARDING_STRATEGY"] = "HYBRID_SHARD_ZERO2"
else:
# Fastest but requires more memory: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.BackwardPrefetch
os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE"
# See: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy
os.environ["FSDP_SHARDING_STRATEGY"] = "HYBRID_SHARD_ZERO2"
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "DDiTBlock"
# SHARDED_STATE_DICT is a bit faster, but more complicated as later on we need to merge the shards.
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig, FullStateDictConfig)
fsdp_plugin = FullyShardedDataParallelPlugin(
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), # SHARDED_STATE_DICT
)
if config.trainer.compile or config.trainer.use_orig_params is True:
# https://github.com/huggingface/transformers/pull/24591/files
fsdp_plugin.use_orig_params = True
rprint("Using orig params for FSDP. This is required for torch.compile to work.")
accelerate_kwargs["fsdp_plugin"] = fsdp_plugin
gradient_kwargs["sync_each_batch"] = False
if getattr(config.trainer, "fsdp_sync_each_batch", False): # Reduce memory usage: https://huggingface.co/docs/accelerate/en/concept_guides/gradient_synchronization#nosync-requires-additional-gpu-memory-when-using-fsdp
rprint("Using sync each batch for Chameleon")
gradient_kwargs["sync_each_batch"] = True
elif config.trainer.xla_spmd is False: # For XLA FSDP, we init where we normally prepare()
rprint("Using DDP...")
ddp_kwargs = DistributedDataParallelKwargs(
find_unused_parameters=config.trainer.find_unused_parameters,
comm_hook=DDPCommunicationHookType.BF16,
static_graph=config.trainer.accumulate_grad_batches == 1,
gradient_as_bucket_view=True,
)
# bucket_cap_mb=32,
# Not needed right now
from datetime import timedelta
from accelerate.utils import InitProcessGroupKwargs
init_process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
accelerate_kwargs["kwargs_handlers"] = [ddp_kwargs, init_process_group_kwargs]
else:
rprint(f"Did not choose DDP or FSDP.")
if config.trainer.accumulate_grad_batches <= 0:
gprint("WARNING!!!!!! Accumulate grad batches is <= 0, setting to 1")
config.trainer.accumulate_grad_batches = 1
gradient_accumulation_plugin = GradientAccumulationPlugin(
num_steps=config.trainer.accumulate_grad_batches,
adjust_scheduler=False, # We manually adjust our LR for accumulate_grad_batches
sync_with_dataloader=False,
**gradient_kwargs
)
if config.trainer.mixed_precision == "bf16" and (is_torch_cuda_available() and not torch.cuda.is_bf16_supported()):
rprint(f"No BF16 GPU found, falling back to FP16")
config.trainer.mixed_precision = "fp16"
if config.trainer.mixed_precision == "fp32":
config.trainer.mixed_precision = "no"
else:
if is_torch_xla_available():
os.environ["ACCELERATE_DOWNCAST_BF16"] = "true"
rprint(f"Mixed precision: {config.trainer.mixed_precision}")
if config.seed is None or getattr(config.eval, 'set_random_gen_seed', False):
        # Don't ask why: seeds get reset by val_epoch_end, so without this the generations in val_epoch_end are identical across GPUs.
accelerate_kwargs["rng_types"] = []
rprint("No seed provided, disabling accelerate RNG synchronization")
accelerator = Accelerator(
mixed_precision=config.trainer.mixed_precision,
log_with=config.trainer.log_with,
project_config=accelerator_project_config,
gradient_accumulation_plugin=gradient_accumulation_plugin,
dataloader_config=DataLoaderConfiguration(split_batches=False, dispatch_batches=False, non_blocking=False),
**accelerate_kwargs,
)
gprint(f"Distributed Type: {accelerator.distributed_type}, Accelerator state: {AcceleratorState()}")
num_processes = AcceleratorState().num_processes
if getattr(config.trainer, "global_num_warmup_steps", None) is not None:
rprint(f"Global num_warmup_steps was: {config.lr_scheduler.num_warmup_steps}. Applying to num_warmup_steps")
config.lr_scheduler.num_warmup_steps = config.trainer.global_num_warmup_steps
if getattr(config.trainer, "global_num_training_steps", None) is not None:
rprint(f"Global num_training_steps was: {config.lr_scheduler.num_training_steps}. Applying to num_training_steps")
config.lr_scheduler.num_training_steps = config.trainer.global_num_training_steps
if not config.trainer.disable_adjust_num_warmup_steps:
rprint(f"Original num_warmup_steps was: {config.lr_scheduler.num_warmup_steps}")
config.lr_scheduler.num_warmup_steps = config.lr_scheduler.num_warmup_steps * num_processes
rprint(f"Setting num_warmup_steps to: {config.lr_scheduler.num_warmup_steps}")
if hasattr(config.lr_scheduler, "num_training_steps"):
rprint(f"Original num_training_steps was: {config.lr_scheduler.num_training_steps}")
config.lr_scheduler.num_training_steps = config.lr_scheduler.num_training_steps * num_processes
rprint(f"Setting num_training_steps to: {config.lr_scheduler.num_training_steps}")
assert config.trainer.allow_dynamic_nodes or (os.environ.get("XLA_USE_SPMD", "0") == "1") or accelerator.num_processes == (
config.trainer.devices * num_nodes
), f"Expected {config.trainer.devices * num_nodes} GPUs but got {accelerator.num_processes} processes."
    compute_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        compute_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        compute_dtype = torch.bfloat16
    if compute_dtype != torch.bfloat16:
        rprint(f"WARNING!!!! Compute dtype is: {compute_dtype}")
    else:
        rprint(f"Compute dtype is: {compute_dtype}")
if is_main_process():
instantiate_wandb(config, accelerator)
run_cmd = get_run_cmd(config)
with open_dict(config):
config.trainer.devices = accelerator.num_processes
        config.trainer.dtype = str(compute_dtype)
if hasattr(config, "state"):
config.state.cmd = run_cmd
else:
config.state = OmegaConf.create(dict(cmd=run_cmd))
OmegaConf.set_readonly(config, True)
if getattr(config.trainer, "attach_oom_observer", False):
from torchtnt.utils.oom import attach_oom_observer
attach_oom_observer(output_dir=str(os.getcwd()), trace_max_entries=500000)
rprint(f"Attached OOM observer to {os.getcwd()}")
train_ds, valid_ds = dataloader.get_dataloaders(config, tokenizer, device=accelerator.device, skip_train=(config.mode == 'eval' and not config.eval.val_with_train_data))
model = Diffusion(config=config, tokenizer=valid_ds.tokenizer, device=accelerator.device)
if is_main_process():
print_params(model.backbone)
try:
if getattr(config.model, "image_model", False) is False:
_print_batch(train_ds, valid_ds, tokenizer)
    except Exception:
pass
get_ema_path = lambda x: Path(x) / "ema.ckpt"
SAMPLER_NAME = "weighted_dataset_sampler"
def save_model_hook(models, weights, output_dir):
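        # Accelerate save_state pre-hook: persists the EMA weights (if any), the resolved config, and (when enabled) the weighted-sampler state alongside the checkpoint.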
nonlocal model, accelerator, train_ds
if is_main_process():
with try_except(write_error_to_file=True):
if getattr(model, "ema", None) is not None:
torch.save(accelerator.unwrap_model(model).ema.state_dict(), get_ema_path(output_dir))
rprint(f"Saved EMA to {get_ema_path(output_dir)}")
save_config_to_ckpt(config, output_dir, model)
with try_except(write_error_to_file=True):
if config.data.use_weighted_tensordict_sampler:
from accelerate.utils import save
output_sampler_file = output_dir.joinpath(f"{SAMPLER_NAME}_train.bin")
save(train_ds.sampler.state_dict(), output_sampler_file, save_on_each_node=False, safe_serialization=False)
rprint(f"Sampler state for dataloader saved in {output_sampler_file}")
initial_global_step = None
def load_model_hook(models, input_dir):
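        # Accelerate load_state pre-hook: restores the global step, EMA weights (re-initializing from model weights if no EMA file exists), and the sampler state.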
nonlocal initial_global_step, model, train_ds
config_path = os.path.join(input_dir, "config.yaml")
ckpt_config = OmegaConf.load(config_path)
initial_global_step = ckpt_config.state.ckpt_step
model.global_step = initial_global_step
try:
if hasattr(config.state, "num_evals"):
model.num_evals = config.state.num_evals
except Exception as e:
rprint(f"Error loading model: {e}")
rprint(f"Loaded global step {initial_global_step}")
state_dict = None
if getattr(config.checkpointing, "load_from_old_attention_format", False):
state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
state_dict = convert_state_dict_keys(state_dict)
if getattr(model, "ema", None) is not None:
if get_ema_path(input_dir).exists():
rprint(f"Loading EMA from {get_ema_path(input_dir)}")
model.ema.load_state_dict(torch.load(get_ema_path(input_dir), map_location='cpu'))
else:
rprint(f"No EMA found, initializing EMA with state_dict")
if state_dict is None:
state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
# We likely don't need the unwrap, but just to be safe
accelerator.unwrap_model(models[0]).load_state_dict(state_dict)
from models.ema import EMAModel
model.ema = EMAModel(accelerator.unwrap_model(models[0]).parameters(), decay=config.trainer.ema)
if config.data.use_weighted_tensordict_sampler and not is_torch_xla_available(): # and not config.eval.test_eval_speed:
input_sampler_file = Path(input_dir).joinpath(f"{SAMPLER_NAME}_train.bin")
if train_ds is not None and input_sampler_file.exists():
train_ds.sampler.load_state_dict(torch.load(input_sampler_file))
rprint("All dataloader sampler states loaded successfully")
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
model.init_dataloader(train_ds, valid_ds)
model.set_accelerator(accelerator, ckpt_path)
model.set_callbacks()
if getattr(config.checkpointing, "load_from_text_model", None) is not None:
rprint(f"Loading from text model")
model.custom_load_checkpoint()
if getattr(config.checkpointing, "load_from_lightning_ckpt", None) is not None:
ckpt = torch.load(config.checkpointing.load_from_lightning_ckpt)
initial_global_step = ckpt["global_step"]
state_dict_ = {k.removeprefix("backbone."): v for k, v in ckpt["state_dict"].items() if "backbone" in k}
state_dict_ = {k.replace(".attn_", ".attention.attn_"): v for k, v in state_dict_.items()}
accelerator.unwrap_model(model.backbone).load_state_dict(state_dict_)
if config.trainer.ema > 0:
model.ema.load_state_dict(ckpt["ema"])
rprint(f"Loaded lightning ckpt: {config.checkpointing.load_from_lightning_ckpt}")
if initial_global_step is not None:
        # The load hooks run before accelerate does its loading, which overwrites model.global_step if we set it there
model.global_step = initial_global_step
rprint(f"Set global step to {initial_global_step}")
contexts = []
if config.trainer.nvtx_profile:
contexts.append(torch.autograd.profiler.emit_nvtx(record_shapes=True))
if config.trainer.profile_memory:
contexts.append(profile_memory())
using_torch_elastic = torch.distributed.is_torchelastic_launched()
if using_torch_elastic:
rprint(f"Torchelastic launched: {torch.distributed.is_torchelastic_launched()}")
contexts.append(ErrorHandler())
with ExitStack() as stack:
for ctx in contexts:
stack.enter_context(ctx)
rprint(f"output_dir: {config.output_dir}")
model.to(accelerator.device)
if config.mode == 'train':
model.train()
elif config.mode == 'eval':
            # Standalone FID eval currently goes through the same validate() path.
            model.validate(None)
elif config.mode == 'zero-shot-eval':
model.zero_shot_eval()
else:
raise ValueError(f"Invalid mode: {config.mode}")
accelerator.end_training()
def get_run_cmd(config):
orig_argv = deepcopy(sys.argv)
prepend_argv = []
if "HYDRA_RUN_DIR" in os.environ:
prepend_argv.append(f"HYDRA_RUN_DIR='{os.environ['HYDRA_RUN_DIR']}'")
else:
prepend_argv.append(f"HYDRA_RUN_DIR='{str(Path(config.output_dir).resolve())}'")
if orig_argv[1].startswith("experiments=["):
orig_argv[1] = orig_argv[1].removeprefix("experiments=[").removesuffix("]")
orig_argv[1] = f"experiments=\'[{orig_argv[1]}]\'"
if os.environ.get("CUSTOM_SBATCH_LAUNCHER", "0") == "1":
sbatch_script_path = 'scripts/slurm.sh'
orig_argv.pop(0)
orig_argv = ["sbatch", f"--nodes={os.environ.get('SLURM_NNODES', '1')}", f"--gpus-per-node={os.environ.get('SLURM_GPUS_PER_NODE', '1')}", f"--partition={os.environ.get('SLURM_JOB_PARTITION', 'all')}", sbatch_script_path] + orig_argv
else:
prepend_argv.append("accelerate launch")
full_run_cmd = " ".join(prepend_argv + orig_argv)
rprint(f"Full run cmd: {full_run_cmd}")
return full_run_cmd
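# Example of the command get_run_cmd reconstructs (paths and overrides below are illustrative, not from a real run):
#   HYDRA_RUN_DIR='/path/to/output_dir' accelerate launch main.py experiments='[my_experiment]' mode=train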
@hydra.main(version_base=None, config_path=CONFIG_PATH, config_name="config")
@try_except()
def main(config):
    """Main entry point for trainer."""
    if is_sweep():
        print(f"Checking if we need to requeue for job id {os.environ['SLURM_JOB_ID']}")
        from unidisc.utils.slurm_requeue import check_requeue
        check_requeue()
        print(f"Done checking if we need to requeue for job id {os.environ['SLURM_JOB_ID']}.")
    import torch  # Imported lazily; importing at module scope causes pickling issues.
if is_torch_xla_available():
builtins.HAS_XLA_SPAWNED = True
os.environ['PJRT_DEVICE'] = 'TPU'
if config.trainer.precision == "bf16":
os.environ['XLA_USE_BF16'] = '1'
if config.devices == 1 and config.trainer.xla_spmd is False and config.trainer.fsdp is False:
os.environ['TPU_PROCESS_BOUNDS'] = '1,1,1'
os.environ['TPU_VISIBLE_CHIPS'] = '0'
gprint(f"Setting TPU_PROCESS_BOUNDS: {os.environ['TPU_PROCESS_BOUNDS']}")
gprint(f"Setting TPU_VISIBLE_CHIPS: {os.environ['TPU_VISIBLE_CHIPS']}")
if config.trainer.tpu_eager:
os.environ['XLA_USE_EAGER_DEBUG_MODE'] = '1'
if config.trainer.tpu_compile_debug:
os.environ['PT_XLA_DEBUG'] = '1'
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'
os.environ['XLA_IR_DEBUG'] = '1'
os.environ['XLA_HLO_DEBUG'] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['TF_CPP_VMODULE'] = 'xla_graph_executor=5,pjrt_computation_client=3'
# We intentionally set these after to avoid import side effects
spmd_mesh, axis_names, num_nodes = None, None, None
if config.trainer.xla_spmd:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from accelerate import PartialState
from torch_xla._internal import tpu
auto_spmd = getattr(config.trainer, "auto_spmd", False)
xr.use_spmd(auto=auto_spmd) # Auto causes a crash
force_global_devices = getattr(config.trainer, "force_global_devices", None)
force_local_devices = getattr(config.trainer, "force_local_devices", None)
assert (force_global_devices is None) == (force_local_devices is None), "Must set both or neither"
if force_global_devices is not None:
num_global_devices = force_global_devices
num_local_devices = force_local_devices
gprint(f"Using force global devices: num_global_devices={num_global_devices}, num_local_devices={num_local_devices}")
else:
num_global_devices = xr.global_runtime_device_count()
num_local_devices = tpu.num_available_devices()
assert num_global_devices == tpu.num_expected_global_devices()
assert tpu.num_available_devices() == tpu.num_available_chips() == tpu.num_local_processes()
num_nodes = num_global_devices // num_local_devices
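        # Mesh axes follow ('dcn', 'fsdp', 'model'): by default, FSDP-shard across a node's local devices and replicate across nodes, with no model parallelism.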
spmd_mesh_shape = getattr(config.trainer, "spmd_mesh", None)
if spmd_mesh_shape is None:
spmd_mesh_shape = (num_nodes, num_local_devices, 1)
if getattr(config.trainer, "force_disable_replicas", False):
spmd_mesh_shape = (1, num_global_devices, 1)
rprint(f"Forcing disable replicas: {spmd_mesh_shape}")
if auto_spmd is False:
if getattr(config.trainer, "spmd_multislice", None) is not None:
from torch_xla.distributed.spmd import HybridMesh
ici_mesh_shape = spmd_mesh_shape
dcn_mesh_shape = (config.trainer.spmd_multislice, 1, 1)
spmd_mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('data','fsdp','tensor'))
rprint(f"Using multislice: {config.trainer.spmd_multislice}: {ici_mesh_shape} {dcn_mesh_shape}, {spmd_mesh.shape()}")
else:
spmd_mesh = xs.Mesh(np.array(range(num_global_devices)), spmd_mesh_shape, ('dcn', 'fsdp', 'model'))
xs.set_global_mesh(spmd_mesh)
config.devices = 1
config.nodes = 1
with read_write(config):
with open_dict(config):
config.state = OmegaConf.create(dict(spmd_mesh=spmd_mesh_shape))
config.state.axis_names = axis_names
config.state.num_nodes = num_nodes
config.state.num_global_devices = num_global_devices
config.state.num_local_devices = num_local_devices
config.state.worker_ips = tpu.get_worker_ips()
if os.environ.get("TPU_NAME") is not None:
config.state.tpu_name = os.environ.get("TPU_NAME")
if config.trainer.tpu_eager:
import torch_xla
torch_xla.experimental.eager_mode(True)
if config.trainer.tpu_profile:
if config.trainer.tpu_profile_markers:
os.environ['XLA_IR_DEBUG'] = '1'
os.environ['XLA_HLO_DEBUG'] = '1'
import torch_xla.debug.profiler as xp
server = xp.start_server(9012)
if config.trainer.tpu_cache:
import torch_xla.runtime as xr
readonly = not is_main_process()
rprint(f"Initializing TPU cache with readonly={readonly}")
xr.initialize_cache(str((Path.home() / '.cache' / 'unidisc' / f"tpu_{get_rank()}_{get_hostname().replace('-', '_')}").resolve()), readonly=readonly)
if config.trainer.enable_jax_smi:
initialise_tracking()
rprint("Initializing jax-smi")
from unidisc.utils.logging_utils import set_logger
set_logger(f"{__name__} {get_slurm_log_prefix()}", Path(f"{get_slurm_filename_info()}_{get_hostname().replace('-', '_')}.out"))
if is_torch_xla_available():
import torch_xla.runtime as xr
gprint(
f"Computed Rank: {get_rank()}, "
f"Is Main Process: {is_main_process()}, "
f"Is Local Main Process: {is_local_main_process()}, "
f"XLA world size: {xr.world_size()}, "
f"XLA Model Ordinal: {xm.get_ordinal()}, "
f"XLA Global Ordinal: {xr.global_ordinal()}, "
f"XLA Supported Devices: {xm.get_xla_supported_devices()}, "
f"Accelerate Local Process Index: {PartialState().local_process_index}, "
f"Task ID: {tpu.task_id()}, "
f"Worker ID: {tpu.worker_id()} "
f"global device count: {xr.global_runtime_device_count()}, "
f"local process count: {xr.local_process_count()}, "
f"local device count: {xr.local_device_count()}, "
f"addressable device count: {xr.addressable_device_count()}, "
f"num_expected_global_devices: {tpu.num_expected_global_devices()}, "
f"num_available_devices: {tpu.num_available_devices()}, "
f"num_available_chips: {tpu.num_available_chips()}, "
f"num_local_processes: {tpu.num_local_processes()}, "
f"process_bounds_size: {tpu.process_bounds_size()}, "
f"get_worker_ips: {tpu.get_worker_ips()}, "
f"Computed Num Nodes: {num_nodes}, "
f"Specified Mesh: {spmd_mesh_shape}, "
f"Specified Mesh Axes: {axis_names}"
)
gprint(f"LIBTPU_INIT_ARGS: {os.environ.get('LIBTPU_INIT_ARGS', 'None')}")
gprint(f"XLA_FLAGS: {os.environ.get('XLA_FLAGS', 'None')}")
if getattr(config.trainer, "disable_ddp_optimizer", False):
torch._dynamo.config.optimize_ddp = False
if config.seed is not None:
if config.mode == 'eval':
config.seed = config.seed + 1000 * int(get_rank())
else:
config.seed = config.seed + int(get_rank())
np.random.seed(config.seed)
random.seed(config.seed)
torch.manual_seed(config.seed)
if is_torch_cuda_available():
# TODO: Is seed all desired? Does it set the same one on all GPUs even in multi-process?
torch.cuda.manual_seed_all(config.seed)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
xm.set_rng_state(config.seed)
gprint(f"Set seed: {config.seed}")
else:
rprint("No seed provided")
_print_config(config, resolve=True, save_cfg=True)
with open(f"env_vars_{get_slurm_filename_info()}_{get_hostname().replace('-', '_')}.txt", "w") as f:
for key, value in os.environ.items():
f.write(f"{key}={value}\n")
tokenizer = dataloader.get_tokenizer(config)
if "tokens" in config.data.train and (config.loader.num_workers > 0 or getattr(config.data, "force_mp_spawn", False)):
from torch import multiprocessing as mp
try:
            rprint(f"Start method already set to: {mp.get_start_method()}")
        except Exception:
mp.set_start_method("spawn")
rprint(f"Start method set to: {mp.get_start_method()}")
rprint(f"Mode: {config.mode}")
if config.mode == "sample_eval":
generate_samples(config, tokenizer)
else:
try:
run(config, tokenizer)
except Exception as e:
rprint(f"Traceback: {traceback.format_exc()}")
rprint(f"Exception: {e}")
timestamp = int(__import__("time").time_ns())
error_filepath = f"exception_{timestamp}_{process_file_prefix()}.out"
with open(error_filepath, "w") as file:
file.write(traceback.format_exc())
rprint(f"See error file {Path(error_filepath).resolve()} for traceback")
if is_torch_xla_available():
exit(1)
if ("SLURM_JOB_ID" not in os.environ) and ("RESTART_FAULT_TOLERANT" not in os.environ) and not is_torch_xla_available():
gprint(f"Entering debugger")
breakpoint(traceback=e.__traceback__)
else:
rprint(f"Not breaking, SLURM_JOB_ID: {os.environ.get('SLURM_JOB_ID')}, RESTART_FAULT_TOLERANT: {os.environ.get('RESTART_FAULT_TOLERANT')}")
if "RESTART_FAULT_TOLERANT" in os.environ:
sigterm_handler = signal.getsignal(signal.SIGTERM)
if callable(sigterm_handler):
rprint(f"Calling SIGTERM handler")
sigterm_handler(signal.SIGTERM, None)
try:
if config.trainer.num_nodes > 1 and config.debug is False and is_main_process():
wandb.alert(title="Exception!", text=f"{e}, {traceback.format_exc()}")
            except Exception:
pass
raise e
finally:
pass
if __name__ == "__main__":
main()