import os
import sys
from contextlib import ExitStack
from pathlib import Path

from constants import CONFIG_PATH, LIB_DIR
sys.path.append(str(LIB_DIR / "hydra_submitit_launcher"))

import builtins
import random
import re
import signal
import traceback
from copy import deepcopy
from datetime import datetime

import hydra
import numpy as np
import omegaconf
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf, open_dict, read_write
from safetensors.torch import load_file, save_file

import dataloader
from model import Diffusion
import utils
import wandb
from decoupled_utils import (check_gpu_memory_usage, get_hostname,
                             get_local_rank, get_rank, get_slurm_filename_info,
                             get_slurm_log_prefix, get_tpu_devices,
                             get_world_size, gprint, is_local_main_process,
                             is_main_process, is_torch_cuda_available,
                             is_torch_xla_available, print_params,
                             process_file_prefix, profile_memory, rank_zero_fn,
                             rprint, set_global_breakpoint, set_global_exists,
                             set_timing_builtins, try_except)
from utils import (ErrorHandler, _print_config, convert_state_dict_keys, set_omega_conf_resolvers, set_torch_defaults)

# Only needed when debugging hydra
# os.environ["HYDRA_FULL_ERROR"] = "1"

set_global_breakpoint()  # Overrides breakpoint() to use ipdb.set_trace() instead and handle distributed training
set_global_exists()
set_omega_conf_resolvers()

if is_torch_xla_available():
    from jax_smi import initialise_tracking

def _load_from_checkpoint(config, tokenizer):
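    """Build a Diffusion model: fresh (on CUDA) for HF backbones, otherwise loaded from config.eval.checkpoint_path."""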
    OmegaConf.resolve(config)
    if "hf" in config.backbone:
        return Diffusion(config=config, tokenizer=tokenizer).to("cuda")

    return Diffusion.load_from_checkpoint(config.eval.checkpoint_path, tokenizer=tokenizer, config=config)

@rank_zero_fn
def _print_batch(train_ds, valid_ds, tokenizer, k=64):
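    """Print the shape and the first/last k decoded tokens of one batch from the train and valid dataloaders (rank zero only)."""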
    for dl_type, dl in [("train", train_ds), ("valid", valid_ds)]:
        rprint(f"Printing {dl_type} dataloader batch.")
        batch = next(iter(dl))
        rprint("Batch input_ids.shape", batch["input_ids"].shape)
        first = batch["input_ids"][0, :k]
        last = batch["input_ids"][0, -k:]
        rprint(f"First {k} tokens:", tokenizer.decode(first))
        rprint("ids:", first)
        rprint(f"Last {k} tokens:", tokenizer.decode(last))
        rprint("ids:", last)


def generate_samples(config, tokenizer):
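    """Load the model from checkpoint and generate config.sampling.num_sample_batches batches of samples, computing generative perplexity for non-semi-AR sampling."""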
    rprint("Generating samples.")
    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
    model.gen_ppl_metric.reset()
    if config.eval.disable_ema:
        rprint("Disabling EMA.")
        model.ema = None
    stride_length = config.sampling.stride_length
    num_strides = config.sampling.num_strides
    for _ in range(config.sampling.num_sample_batches):
        if config.sampling.semi_ar:
            _, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
                stride_length=stride_length, num_strides=num_strides, dt=1 / config.sampling.steps
            )
            text_samples = intermediate_samples[-1]
            # Note: Samples generated using the semi-AR method
            # need to be processed before computing generative perplexity,
            # since these samples contain numerous <|endoftext|> tokens
            # and diffusion.compute_generative_perplexity() discards
            # any text after the first EOS token.
        else:
            samples = model.restore_model_and_sample(num_steps=config.sampling.steps)
            text_samples = model.tokenizer.batch_decode(samples)
            model.compute_generative_perplexity(text_samples)
    
    rprint("Text samples:", text_samples)
    if not config.sampling.semi_ar:
        rprint("Generative perplexity:", model.gen_ppl_metric.compute())
    return text_samples


def instantiate_wandb(config, accelerator):
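    """Initialize the wandb tracker through the accelerator, optionally upload the code, and store the run URL and custom metrics in the config."""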
    if is_torch_xla_available():
        gprint("Initializing wandb for XLA")
    if config.mode == 'eval':
        config.wandb.project = f"{config.wandb.project}-eval"
    elif config.mode == 'zero-shot-eval':
        config.wandb.project = f"{config.wandb.project}-zero-shot-eval"

    if config.wandb.group is not None:
        config.wandb.group = str(config.wandb.group)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    wandb_kwargs = dict(config.wandb)

    if getattr(config, "sweep_id", None) is not None:
        rprint(f"Setting Wandb group to {config.sweep_id}")
        wandb_kwargs["group"] = config.sweep_id
    del wandb_kwargs["project"]
    accelerator.init_trackers(
        config.wandb.project, config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True), init_kwargs=dict(wandb=wandb_kwargs)
    )

    if getattr(config.trainer, "log_code", True) and is_main_process():
        if "matrix" in get_hostname():
            rprint(f"Not logging code to wandb on {get_hostname()}")
        else:
            rprint(f"Logging code to wandb from {Path(__file__).parent}")
            try:
                wandb.run.log_code(
                    root=str(Path(__file__).parent),
                    include_fn=lambda path: any(path.endswith(f) for f in (".py", ".yaml", ".yml", ".txt", ".md")),
                    exclude_fn=lambda path, root: any(x in os.path.relpath(path, root) for x in ("output", "multirun", "logs", "wandb")),
                )
            except Exception as e:
                rprint(f"Failed to log code to wandb: {e}")

    with open_dict(config):
        try:
            config.wandb_url = wandb.run.get_url()
            wandb.define_metric("global_samples")
            wandb.define_metric("effective_global_tokens")
            wandb.define_metric("effective_global_step")
            wandb.define_metric("train_metrics/samples")
            wandb.define_metric("trainer/loss", step_metric="global_samples")
        except Exception as e:
            rprint(f"Failed to get wandb url: {e}")

def instantiate_model(config, tokenizer):
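    """Load the model from checkpoint and optionally disable EMA."""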
    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
    if config.eval.disable_ema:
        rprint("Disabling EMA.")
        model.ema = None

    return model

def gconf(config, attr):
    return getattr(config, attr, None)


def has_ckpt(config, attr):
    return gconf(config, attr) is not None and utils.fsspec_exists(gconf(config, attr))


def set_env_vars(config):
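    """Configure process-level settings: RLIMIT_MEMLOCK, NCCL/SDPA environment variables, and anomaly detection."""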
    import socket
    import torch

    hostname = socket.gethostname()
    rprint(f"Starting Training on {hostname}")
    # os.environ["TORCHINDUCTOR_CACHE_DIR"] = str((Path.home() / ".cache" / "torchinductor").resolve())

    if not is_torch_xla_available():
        try:
            # Applies the equivalent of ulimit -l unlimited to this process [and children]
            # This caused a significant amount of pain to figure out
            import resource
            soft, hard = resource.getrlimit(resource.RLIMIT_MEMLOCK)
            resource.setrlimit(resource.RLIMIT_MEMLOCK, (hard, hard))
            if is_local_main_process():
                gprint(f"Successfully set RLIMIT_MEMLOCK to {hard}")
        except ValueError as e:
            rprint(f"Failed to set RLIMIT_MEMLOCK: {e}")
        except resource.error as e:
            rprint(f"Error setting RLIMIT_MEMLOCK: {e}")
    else:
        rprint(f"Not setting RLIMIT_MEMLOCK on XLA")

    if "matrix-3-28" in hostname or "matrix-3-26" in hostname:
        rprint(f"Disabling NCCL P2P")
        os.environ["NCCL_P2P_DISABLE"] = "1"

    if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") != "":
        assert False, f"TORCH_DISTRIBUTED_DEBUG is set to: {os.environ.get('TORCH_DISTRIBUTED_DEBUG')}. Please unset it as it starts a gloo backend."

    if config.model.use_spda_attn:
        os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
        os.environ["TORCH_CUDNN_MHA_ENABLED"] = "1"
        rprint("Setting SPDA Flags")

    if config.trainer.detect_anomaly:
        torch.autograd.set_detect_anomaly(True)

def update_config_before_resolution(config):
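    """Adjust the config in place before OmegaConf resolution (output dirs, batch sizes, sampling steps, checkpointing/restart flags)."""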
    import torch
    if hasattr(config, "training"):
        rprint(f"'training' has been refactored to 'trainer'. Please update the config.")
        
    with open_dict(config):
        config.output_dir = os.getcwd()
        config.logging_dir = os.getcwd()
        if config.model.use_kv_cache is False and config.mode == "eval" and config.loader.eval_batch_size > 1:
            config.loader.eval_batch_size = max(config.loader.eval_batch_size, 16)
        
        # todo revert?
        if getattr(config.eval, 'txt_img_ratio', None) is not None:
            # 2,1,0.5,0.25
            tot = config.model.length
            # if it's 2:1, then distribute the tokens as 2/3, 1/3
            # if it's 1:1, then distribute the tokens as 1/2, 1/2
            # if it's 0.5:1, then distribute the tokens as 2/3, 1/3
            # if it's 0.25:1, then distribute the tokens as 1/4, 3/4
            if config.eval.txt_img_ratio == 2:
                # do first 2/3 tokens as text, last 1/3 as image
                config.model.txt_length = int(tot * 2/3)
            elif config.eval.txt_img_ratio == 1:
                config.model.txt_length = int(tot / 2)
            elif config.eval.txt_img_ratio == 0.5:
                config.model.txt_length = int(tot * 2/3)
            elif config.eval.txt_img_ratio == 0.25:
                config.model.txt_length = int(tot / 4)
            config.model.img_length = tot - config.model.txt_length
            config.model.length = config.model.txt_length + config.model.img_length
            # config.eval.attention_caching_txt_to_img_ratio = config.model.txt_length // 20
            
        if getattr(config.eval, "varying_seq_len_ratio", False):
            assert getattr(config.eval, "sampling_step_ratio", None) is not None, "Must set both varying_seq_len_ratio and sampling_step_ratio"
            config.sampling.steps = int(config.model.length * config.eval.sampling_step_ratio)

        if getattr(config.eval, "ablation_config", False):
            if config.parameterization == "ar":
                rprint(f"WARNING!!!!! FORCING AR PARAMS")
                config.trainer.ar_shift = True
                config.model.full_attention = False

            config.data.keep_tensordict_on_disk = True
            if is_torch_cuda_available():
                if any(x.lower() in torch.cuda.get_device_name().lower() for x in ["v100", "1080", "2080", "quadro", "titan"]) or torch.cuda.get_device_capability()[0] <= 7:
                    rprint(f"Using 2080Ti/V100, setting precision to fp32")
                    config.trainer.precision = "no"
                    config.model.force_optimized_native_attn = False
                    config.trainer.compile = False
                    if any(x.lower() in torch.cuda.get_device_name().lower() for x in ["2080", "quadro"]):
                        config.loader.eval_batch_size = config.loader.eval_batch_size // 7
                        config.loader.batch_size = config.loader.batch_size // 7
                    elif any(x.lower() in torch.cuda.get_device_name().lower() for x in ["1080", "titan"]):
                        config.loader.eval_batch_size = config.loader.eval_batch_size // 6
                        config.loader.batch_size = config.loader.batch_size // 6
                    else:
                        config.loader.eval_batch_size = config.loader.eval_batch_size // 2
                        config.loader.batch_size = config.loader.batch_size // 2
                elif "a5000" in torch.cuda.get_device_name().lower() or "a4500" in torch.cuda.get_device_name().lower():
                    config.loader.eval_batch_size = config.loader.eval_batch_size // 2
                    config.loader.batch_size = config.loader.batch_size // 2
                else:
                    rprint(f"Found {torch.cuda.get_device_name()}")
                    config.loader.eval_batch_size = config.loader.eval_batch_size // 2
                    config.loader.batch_size = config.loader.batch_size // 2
            
            if getattr(config, "parametierzation", None) == "ar" and config.eval.cfg is not None:
                config.loader.eval_batch_size = config.loader.eval_batch_size // 2
                config.loader.batch_size = config.loader.batch_size // 2

            config.loader.eval_batch_size = max(config.loader.eval_batch_size, 1)
            config.loader.batch_size = max(config.loader.batch_size, 1)

            if getattr(config, "parametierzation", None) == "ar":
                config.trainer.compile = False
            
        if getattr(config.sampling, "sampling_step_frac", None) is not None:
            config.sampling.steps = int(config.model.length * config.sampling.sampling_step_frac)
            rprint(f"Setting sampling steps to {config.sampling.steps}")
        
        if os.environ.get("SUBMITIT_FOLDER") is not None or os.environ.get("CUSTOM_SBATCH_LAUNCHER", "0") == "1":
            rprint(f'Using submitit folder: {os.environ.get("SUBMITIT_FOLDER", "")}, setting slurm=True')
            config.slurm = True

        if (config.debug is False or os.environ.get("HYDRA_RUN_DIR_NAME", None) is not None) and torch.distributed.is_torchelastic_launched():
            config.trainer.restart_on_failure = True
            rprint(f"Setting restart_on_failure to True")

        if config.trainer.restart_on_failure and config.mode == 'train':
            if os.environ.get("HYDRA_RUN_DIR", None) is None and os.environ.get("HYDRA_RUN_DIR_NAME", None) is None:
                os.environ["HYDRA_RUN_DIR"] = config.output_dir
                rprint(f"Setting HYDRA_RUN_DIR to {os.environ['HYDRA_RUN_DIR']}")
            else:
                rprint(f"Not setting HYDRA_RUN_DIR, already set to {os.environ.get('HYDRA_RUN_DIR', 'N/A')}, and HYDRA_RUN_DIR_NAME is set to {os.environ.get('HYDRA_RUN_DIR_NAME', 'N/A')}")

            os.environ["RESTART_FAULT_TOLERANT"] = "1"
            rprint(f"Setting RESTART_FAULT_TOLERANT to 1")
        elif config.trainer.restart_on_failure:
            rprint(f"Restart_on_failure is True, but mode is not 'train', so not setting restart fault tolerant")

        relevant_vars = {}
        for key, value in os.environ.items():
            if "SLURM" in key or "NCCL" in key or "TORCH" in key:
                relevant_vars[key] = value

        config.env_vars = relevant_vars

        if config.trainer.profile_memory:
            config.trainer.max_steps = 2

        if config.debug and config.trainer.force_enable_checkpointing is False and (config.trainer.ckpt_steps is None or config.trainer.ckpt_steps > 0):
            config.trainer.ckpt_steps = 10000
            rprint(f"Only checkpointing every {config.trainer.ckpt_steps} steps in debug mode")

        if config.loader.global_batch_size is None:
            config.loader.global_batch_size = config.loader.batch_size * config.trainer.accumulate_grad_batches * (1 if is_torch_xla_available() else get_world_size())
            config.loader.eval_global_batch_size = config.loader.global_batch_size
            if config.trainer.scale_lr_by_batch_size:
                config.optim.lr = config.optim.lr * (config.loader.global_batch_size / 512)
            rprint(f"Setting global batch size to {config.loader.global_batch_size}, lr to {config.optim.lr}")

        if config.mode != 'train':
            config.checkpointing.resume_wandb = False
            config.wandb.resume = None

        if config.trainer.use_spmd_distributed_checkpointing is None:
            config.trainer.use_spmd_distributed_checkpointing = is_torch_xla_available() and config.trainer.xla_spmd

        if config.trainer.disable_all_eval_generation:
            config.eval.num_masking_viz_batches = 0
            config.eval.num_uncond_sample_batches = 0
            config.eval.num_sample_batches = 0
            config.eval.num_random_masking = 0
            config.eval.generate_samples = False
            config.trainer.log_flops = False
            config.eval.log_every_n_evals = -1
            config.eval.log_every_n_fid = -1
            config.model.image_model_fid_eval = False
            rprint("Disabling all eval generation!!!")

        if os.environ.get("XLA_IR_DEBUG", "0") == "1":
            config.trainer.tpu_profile = True

        if config.checkpointing_root_dir is not None:
            assert "checkpoints" in config.checkpointing.save_dir
            relative_path = Path(*Path(config.checkpointing.save_dir).relative_to(config.root_output_dir).parts[1:])
            full_checkpointing_dir = Path(config.checkpointing_root_dir) / relative_path
            if config.checkpointing_root_dir is not None:
                old_save_dir = Path(config.output_dir) / "checkpoints"
                full_checkpointing_dir.mkdir(parents=True, exist_ok=True)
                try:
                    if old_save_dir.exists():
                        rprint(f"WARNING: Cannot create symlink from {old_save_dir} to {full_checkpointing_dir} because {old_save_dir} exists.")
                        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                        old_save_dir = Path(*old_save_dir.parts[:-1]) / f"checkpoints_{timestamp}"
                    
                    old_save_dir.symlink_to(full_checkpointing_dir, target_is_directory=True)
                    rprint(f"Created softlink from {old_save_dir} to {full_checkpointing_dir}")

                    # Create a symlink from the parent of full_checkpointing_dir named "original" back to config.output_dir
                    original_link = full_checkpointing_dir.parent / "original_output_dir"
                    if not original_link.exists():
                        original_link.symlink_to(Path(config.output_dir).resolve(), target_is_directory=True)
                        rprint(f"Created softlink from {original_link} to {config.output_dir}")
                    else:
                        rprint(f"WARNING: Symlink {original_link} already exists. Skipping creation.")

                except OSError as e:
                    rprint(f"Error creating softlinks: {e}")

        assert getattr(config.data, "allow_label", False) == getattr(config.trainer, "add_label", False) == (getattr(config.model, "add_labels", None) is not None) == getattr(config.eval, "class_conditional_fid", False), f"Mismatching values: data.allow_label={config.data.allow_label}, trainer.add_label={config.trainer.add_label}, model.add_labels={config.model.add_labels}, eval.class_conditional_fid={config.eval.class_conditional_fid}"

        if getattr(config.loader, "num_eval_workers", None) is not None and config.loader.num_workers == 0:
            rprint(f"Setting num_eval_workers to 0 because num_workers is 0")
            config.loader.num_eval_workers = 0

    if config.trainer.disable_all_checkpointing:
        gprint("-"*50)
        gprint(f"WARNING: DISABLING ALL CHECKPOINTING!!!!")
        gprint("-"*50)
        gprint(f"WARNING: DISABLING ALL CHECKPOINTING!!!!")
        gprint("-"*50)
        config.trainer.ckpt_steps = 100000000

    if config.sampling.steps != config.sampling.max_sampling_steps:
        rprint(f"WARNING!!!! steps {config.sampling.steps} != max_sampling_steps {config.sampling.max_sampling_steps}")
        config.sampling.max_sampling_steps = config.sampling.steps

def get_latest_ckpt(config, input_dir):
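    """Return the most recent checkpoint directory under input_dir (by trailing step number), or None if nothing usable exists. For XLA SPMD, input_dir itself is returned if non-empty."""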
    if input_dir is None or not Path(input_dir).exists():
        rprint(f"Project dir {input_dir} does not exist")
        return None
    
    if config.trainer.xla_spmd and is_torch_xla_available():
        rprint(f"XLA SPMD detected, using XLA checkpointing")
        if any(Path(input_dir).iterdir()):
            rprint(f"Found existing files/folders in {input_dir}")
            return input_dir
        else:
            rprint(f"No folders found in {input_dir}")
            return None

    folders = [str(folder) for folder in Path(input_dir).iterdir() if folder.is_dir() and ((folder / "model.safetensors").exists() or (folder / "config.yaml").exists())]

    if len(folders) == 0:
        rprint(f"No folders found in {input_dir}")
        return None

    def _inner(folder):
        return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]

    folders.sort(key=_inner)
    rprint(f"Found folders: {folders}")
    input_dir = folders[-1]
    return input_dir

def is_sweep():
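    """Return True when running inside a Hydra sweep (i.e. sweep.subdir resolves)."""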
    try:
        subdir = HydraConfig.get().sweep.subdir
        rprint(f"Found sweep subdir: {subdir}")
        return True
    except omegaconf.errors.InterpolationToMissingValueError:
        return False
    
def get_sweep_run_name(config):
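    """Build a run-name suffix from the Hydra sweep subdir/overrides and any forced keys (falls back to the SLURM job id outside a sweep)."""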
    try:
        subdir = HydraConfig.get().sweep.subdir
        sweep_str = f"{subdir}_"
        is_sweep = True
    except omegaconf.errors.InterpolationToMissingValueError:
        is_sweep = False
        sweep_str = f"{os.environ.get('SLURM_JOB_ID', '')}_"

    if getattr(config, "training", None) is not None and getattr(getattr(config, "training", None), "force_keys", None) is not None:
        rprint("Using legacy keys")
        forced_keys = set(config.training.force_keys)
    else:
        forced_keys = set(getattr(config.trainer, "forced_keys", []))

    if is_sweep:
        print(
            f"Getting sweep keys: {HydraConfig.get().job.sweep_keys}, Tasks: {HydraConfig.get().overrides.task}, {getattr(config.trainer, 'forced_keys', [])}"
        )
        valid_keys = set(HydraConfig.get().job.sweep_keys)
        for task in HydraConfig.get().overrides.task:
            parts = task.removeprefix("+").split("=")
            key = parts[0]
            if key in valid_keys or key in forced_keys:
                value = parts[1]
                sweep_str += f"{key.split('.')[-1]}={value}__"
                if key in forced_keys:
                    forced_keys.remove(key)
                    print(f"Forced key: {key}={value}")

    for key in sorted(list(forced_keys)):
        sweep_str += f"{key.split('.')[-1]}={OmegaConf.select(config, key)}__"

    rprint(f"Sweep: {is_sweep=}, {sweep_str=}")
    return "" if sweep_str == "" else sweep_str[:-2]

def save_config_to_ckpt(config, output_dir, model):
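    """Write the current config (including ckpt_step and num_evals) to config.yaml in output_dir."""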
    with try_except(write_error_to_file=True, clear_cuda_cache=True):
        with read_write(config):
            with open_dict(config):
                config.state.ckpt_step = model.global_step
                config.state.num_evals = model.num_evals

        OmegaConf.save(config=config, f=Path(output_dir) / "config.yaml")
        rprint(f"Saved global step {model.global_step}")

def determine_ckpt(config):
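    """Determine which checkpoint to resume from (if any) and set up the wandb id/resume mode and run name accordingly. Returns the checkpoint path or None."""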
    has_recent_ckpt = False
    rprint(f"Looking at checkpoint path: {getattr(config.checkpointing, 'resume_ckpt_path', None)}")
    if (
        config.checkpointing.resume_from_ckpt
        and (latest_ckpt := get_latest_ckpt(config, getattr(config.checkpointing, "resume_ckpt_path", None))) is not None
        and (Path(latest_ckpt) / "config.yaml").exists()
    ):
        ckpt_path = latest_ckpt
        has_recent_ckpt = True
        if config.slurm:
            config.wandb.resume = "must"
        rprint(f"Resuming from checkpoint {ckpt_path}")
    elif config.checkpointing.resume_from_ckpt and getattr(config.checkpointing, "initial_resume_ckpt_path", None) is not None:
        ckpt_path = config.checkpointing.initial_resume_ckpt_path
        rprint(f"Resuming from initial checkpoint {ckpt_path}")
    else:
        ckpt_path = None

    if ckpt_path is not None and (config.checkpointing.resume_wandb or has_recent_ckpt):
        loaded = OmegaConf.load(Path(ckpt_path) / "config.yaml")
        if loaded.wandb.id is not None:
            config.wandb.id = str(loaded.wandb.id)
            rprint(f"Found wandb id: {config.wandb.id}")
        else:
            rprint(f"No wandb id found in checkpoint {ckpt_path}")

    if config.checkpointing.resume_wandb and config.wandb.id is not None:
        config.wandb.resume = "must"
        rprint(f"Resuming wandb, setting must, run id: {config.wandb.id}")
    elif config.slurm and config.wandb.id is None:
        if os.environ.get("SLURM_ARRAY_TASK_COUNT", "") != "" and int(os.environ.get("SLURM_ARRAY_TASK_COUNT", "")) > 1:
            config.wandb.id = str(os.environ.get("SLURM_ARRAY_JOB_ID")) + f"_{os.environ.get('SLURM_ARRAY_TASK_ID')}"
        else:
            config.wandb.id = str(os.environ.get("SLURM_JOB_ID"))
        rprint(f"Setting wandb id to {config.wandb.id}")

    if config.checkpointing.initial_resume_ckpt_path is not None and config.checkpointing.resume_wandb:
        assert config.wandb.id is not None

    if config.ckpt is not None:
        ckpt_path = config.ckpt
        rprint(f"Running eval with checkpoint {ckpt_path}")

    if config.wandb.id is not None:
        config.wandb.id = str(config.wandb.id)

    if config.wandb.id is None or getattr(config.trainer, "force_new_wandb_id", False):
        config.wandb.id = wandb.util.generate_id()
        config.wandb.resume = "allow"
        rprint(f"Set wandb id: {config.wandb.id}")

    rprint(f"Using wandb id: {config.wandb.id}")
    subdir = get_sweep_run_name(config)
    rprint(f"Wandb name: {config.wandb.name}, Wandb subdir: {subdir}")

    if config.wandb.name == 'default':
        config.wandb.name = None
    else:
        config.wandb.name = (
            (f"{config.wandb.name}_" if config.wandb.name else "")
            + (f"{subdir}_" if (subdir is not None and subdir != "") else "")
            + f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        )

    if getattr(config.wandb, "group", None) is None and subdir is not None and config.debug and os.environ.get("SLURM_ARRAY_JOB_ID", "") != "":
        config.wandb.group = os.environ.get("SLURM_ARRAY_JOB_ID")
        rprint(f"Wandb group: {config.wandb.group}")

    return ckpt_path

def run(config, tokenizer):
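    """Set up Accelerate (DDP/FSDP), wandb, dataloaders, and the Diffusion model, then run training, eval, or zero-shot eval according to config.mode."""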
    import torch
    from accelerate import (Accelerator, DataLoaderConfiguration,
                            DDPCommunicationHookType,
                            DistributedDataParallelKwargs,
                            FullyShardedDataParallelPlugin)
    from accelerate.state import AcceleratorState
    from accelerate.utils import GradientAccumulationPlugin, ProjectConfiguration

    set_torch_defaults(config.trainer.benchmark)

    set_env_vars(config)
    update_config_before_resolution(config)
    ckpt_path = determine_ckpt(config)
    OmegaConf.resolve(config)
    if is_torch_cuda_available():
        check_gpu_memory_usage()

    if is_torch_cuda_available():
        rprint(f"pt={torch.__version__}, cuda={torch.version.cuda}, nccl={torch.cuda.nccl.version()}")
        rprint(f"GPU={torch.cuda.get_device_name()}, device compute capabilities={torch.cuda.get_device_capability()}, pytorch compute capabilities={torch.cuda.get_arch_list()}")
    elif is_torch_xla_available():
        rprint(f"XLA Devices={get_tpu_devices()}")

    rprint(
        f"Initial GROUP_RANK: {os.environ.get('GROUP_RANK', 'N/A')}, RANK: {os.environ.get('RANK', 'N/A')}, LOCAL_RANK: {os.environ.get('LOCAL_RANK', 'N/A')}, WORLD_SIZE: {os.environ.get('WORLD_SIZE', 'N/A')}, MASTER_ADDR: {os.environ.get('MASTER_ADDR', 'N/A')}, MASTER_PORT: {os.environ.get('MASTER_PORT', 'N/A')}, TORCHELASTIC_RUN_ID: {os.environ.get('TORCHELASTIC_RUN_ID', 'N/A')}, TORCHELASTIC_RESTART_COUNT: {os.environ.get('TORCHELASTIC_RESTART_COUNT', 'N/A')}, TORCHELASTIC_MAX_RESTARTS: {os.environ.get('TORCHELASTIC_MAX_RESTARTS', 'N/A')}, LOCAL_WORLD_SIZE: {os.environ.get('LOCAL_WORLD_SIZE', 'N/A')}, Elastic: {torch.distributed.is_torchelastic_launched()}"
    )
    rprint(f"Computed Rank: {get_rank()}, Local Rank: {get_local_rank()}, World Size: {get_world_size()}")

    # This lets us have start_timing and end_timing functions and a global enable/disable
    # We always use torch.cuda.synchronize before/after as otherwise the timing is not very meaningful
    sync_timing = (config.trainer.nvtx_profile and getattr(config.trainer, "sync_nvtx_timing", True)) or getattr(config.trainer, "sync_timing", False)
    set_timing_builtins(enable=config.trainer.nvtx_profile, sync=sync_timing)

    num_nodes = config.trainer.num_nodes
    with open_dict(config):
        config.trainer = OmegaConf.merge(config.trainer, dict(mixed_precision=config.trainer.precision, log_with="wandb", log_gradients=None))
    if getattr(config.trainer, "process_dataloader_only", False):
        gprint("Processing dataloader only")
        train_ds, valid_ds = dataloader.get_dataloaders(config, tokenizer, device="cpu", skip_train=(config.mode == 'eval' and not config.eval.val_with_train_data))
        gprint(f"Exiting after processing dataloader")
        return

    accelerator_project_config = ProjectConfiguration(
        project_dir=config.output_dir,
        logging_dir=config.logging_dir,
        automatic_checkpoint_naming=config.checkpointing.use_automatic_naming,
        save_on_each_node=False,
    )

    accelerate_kwargs = dict()
    gradient_kwargs = dict()
    if config.trainer.fsdp and not (config.trainer.xla_spmd and is_torch_xla_available()):
        rprint("Using FSDP...")
        if config.backbone == "llama" or config.backbone == "gemma":
            os.environ["ACCELERATE_USE_FSDP"] = "true"
            os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
            os.environ["FSDP_BACKWARD_PREFETCH"] = "NO_PREFETCH" # Saved memory
            os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
            os.environ["FSDP_FORWARD_PREFETCH"] = "false"
            os.environ["FSDP_OFFLOAD_PARAMS"] = "false"
            os.environ["FSDP_SHARDING_STRATEGY"] = "FULL_SHARD"
            os.environ["FSDP_STATE_DICT_TYPE"] = "SHARDED_STATE_DICT"
            os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
            os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
            fsdp_plugin = FullyShardedDataParallelPlugin()
        else:
            os.environ["ACCELERATE_USE_FSDP"] = "true"
            os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"  # or "SIZE_BASED_WRAP"
            if config.backbone == "elm":
                os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "OpenELMDecoderLayer"
                os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE" 
                os.environ["FSDP_SHARDING_STRATEGY"] = "HYBRID_SHARD_ZERO2"
            else:
                # Fastest but requires more memory: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.BackwardPrefetch
                os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE" 
                # See: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy
                os.environ["FSDP_SHARDING_STRATEGY"] = "HYBRID_SHARD_ZERO2"
                os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "DDiTBlock" 

            # SHARDED_STATE_DICT is a bit faster, but more complicated as later on we need to merge the shards.
            from torch.distributed.fsdp.fully_sharded_data_parallel import (
                FullOptimStateDictConfig, FullStateDictConfig)
            fsdp_plugin = FullyShardedDataParallelPlugin(
                state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
                optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), # SHARDED_STATE_DICT
            )

        if config.trainer.compile or config.trainer.use_orig_params is True:
            # https://github.com/huggingface/transformers/pull/24591/files
            fsdp_plugin.use_orig_params = True
            rprint("Using orig params for FSDP. This is required for torch.compile to work.")

        accelerate_kwargs["fsdp_plugin"] = fsdp_plugin
        gradient_kwargs["sync_each_batch"] = False

        if getattr(config.trainer, "fsdp_sync_each_batch", False): # Reduce memory usage: https://huggingface.co/docs/accelerate/en/concept_guides/gradient_synchronization#nosync-requires-additional-gpu-memory-when-using-fsdp
            rprint("Using sync each batch for Chameleon")
            gradient_kwargs["sync_each_batch"] = True

    elif config.trainer.xla_spmd is False: # For XLA FSDP, we init where we normally prepare()
        rprint("Using DDP...")
        ddp_kwargs = DistributedDataParallelKwargs(
            find_unused_parameters=config.trainer.find_unused_parameters,
            comm_hook=DDPCommunicationHookType.BF16,
            static_graph=config.trainer.accumulate_grad_batches == 1,
            gradient_as_bucket_view=True,
        )
        # bucket_cap_mb=32,

        # Not needed right now
        from datetime import timedelta

        from accelerate.utils import InitProcessGroupKwargs
        init_process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
        accelerate_kwargs["kwargs_handlers"] = [ddp_kwargs, init_process_group_kwargs]
    else:
        rprint(f"Did not choose DDP or FSDP.")

    if config.trainer.accumulate_grad_batches <= 0:
        gprint("WARNING!!!!!! Accumulate grad batches is <= 0, setting to 1")
        config.trainer.accumulate_grad_batches = 1

    gradient_accumulation_plugin = GradientAccumulationPlugin(
        num_steps=config.trainer.accumulate_grad_batches,
        adjust_scheduler=False, # We manually adjust our LR for accumulate_grad_batches
        sync_with_dataloader=False,
        **gradient_kwargs
    )

    if config.trainer.mixed_precision == "bf16" and (is_torch_cuda_available() and not torch.cuda.is_bf16_supported()):
        rprint(f"No BF16 GPU found, falling back to FP16")
        config.trainer.mixed_precision = "fp16"

    if config.trainer.mixed_precision == "fp32":
        config.trainer.mixed_precision = "no"
    else:
        if is_torch_xla_available():
            os.environ["ACCELERATE_DOWNCAST_BF16"] = "true"

    rprint(f"Mixed precision: {config.trainer.mixed_precision}")

    if config.seed is None or getattr(config.eval, 'set_random_gen_seed', False):
        # Has to do with seeds being reset in val_epoch_end; without this, generations in val_epoch_end would be identical across GPUs
        accelerate_kwargs["rng_types"] = []
        rprint("No seed provided, disabling accelerate RNG synchronization")

    accelerator = Accelerator(
        mixed_precision=config.trainer.mixed_precision,
        log_with=config.trainer.log_with,
        project_config=accelerator_project_config,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
        dataloader_config=DataLoaderConfiguration(split_batches=False, dispatch_batches=False, non_blocking=False),
        **accelerate_kwargs,
    )

    gprint(f"Distributed Type: {accelerator.distributed_type}, Accelerator state: {AcceleratorState()}")
    num_processes = AcceleratorState().num_processes
    if getattr(config.trainer, "global_num_warmup_steps", None) is not None:
        rprint(f"Global num_warmup_steps was: {config.lr_scheduler.num_warmup_steps}. Applying to num_warmup_steps")
        config.lr_scheduler.num_warmup_steps = config.trainer.global_num_warmup_steps

    if getattr(config.trainer, "global_num_training_steps", None) is not None:
        rprint(f"Global num_training_steps was: {config.lr_scheduler.num_training_steps}. Applying to num_training_steps")
        config.lr_scheduler.num_training_steps = config.trainer.global_num_training_steps

    if not config.trainer.disable_adjust_num_warmup_steps:
        rprint(f"Original num_warmup_steps was: {config.lr_scheduler.num_warmup_steps}")
        config.lr_scheduler.num_warmup_steps = config.lr_scheduler.num_warmup_steps * num_processes
        rprint(f"Setting num_warmup_steps to: {config.lr_scheduler.num_warmup_steps}")

        if hasattr(config.lr_scheduler, "num_training_steps"):
            rprint(f"Original num_training_steps was: {config.lr_scheduler.num_training_steps}")
            config.lr_scheduler.num_training_steps = config.lr_scheduler.num_training_steps * num_processes
            rprint(f"Setting num_training_steps to: {config.lr_scheduler.num_training_steps}")

    assert config.trainer.allow_dynamic_nodes or (os.environ.get("XLA_USE_SPMD", "0") == "1") or accelerator.num_processes == (
        config.trainer.devices * num_nodes
    ), f"Expected {config.trainer.devices * num_nodes} GPUs but got {accelerator.num_processes} processes."

    compute_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        compute_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        compute_dtype = torch.bfloat16

    if compute_dtype != torch.bfloat16:
        rprint(f"WARNING!!!! Compute dtype is: {compute_dtype}")
    else:
        rprint(f"Compute dtype is: {compute_dtype}")

    if is_main_process():
        instantiate_wandb(config, accelerator)

    run_cmd = get_run_cmd(config)
    with open_dict(config):
        config.trainer.devices = accelerator.num_processes
        config.trainer.dtype = str(compute_dtype)
        if hasattr(config, "state"):
            config.state.cmd = run_cmd
        else:
            config.state = OmegaConf.create(dict(cmd=run_cmd))

    OmegaConf.set_readonly(config, True)

    if getattr(config.trainer, "attach_oom_observer", False):
        from torchtnt.utils.oom import attach_oom_observer
        attach_oom_observer(output_dir=str(os.getcwd()), trace_max_entries=500000)
        rprint(f"Attached OOM observer to {os.getcwd()}")
    train_ds, valid_ds = dataloader.get_dataloaders(config, tokenizer, device=accelerator.device, skip_train=(config.mode == 'eval' and not config.eval.val_with_train_data))
    model = Diffusion(config=config, tokenizer=valid_ds.tokenizer, device=accelerator.device)

    if is_main_process():
        print_params(model.backbone)

    try:
        if getattr(config.model, "image_model", False) is False:
            _print_batch(train_ds, valid_ds, tokenizer)
    except Exception:
        # Printing a sample batch is best-effort; ignore any failure here.
        pass

    get_ema_path = lambda x: Path(x) / "ema.ckpt"
    SAMPLER_NAME = "weighted_dataset_sampler"

    def save_model_hook(models, weights, output_dir):
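        """Accelerate save-state hook: saves EMA weights, the resolved config, and the weighted sampler state (main process only)."""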
        nonlocal model, accelerator, train_ds

        if is_main_process():
            with try_except(write_error_to_file=True):
                if getattr(model, "ema", None) is not None:
                    torch.save(accelerator.unwrap_model(model).ema.state_dict(), get_ema_path(output_dir))
                    rprint(f"Saved EMA to {get_ema_path(output_dir)}")

            save_config_to_ckpt(config, output_dir, model)

            with try_except(write_error_to_file=True):
                if config.data.use_weighted_tensordict_sampler:
                    from accelerate.utils import save
                    output_sampler_file = output_dir.joinpath(f"{SAMPLER_NAME}_train.bin")
                    save(train_ds.sampler.state_dict(), output_sampler_file, save_on_each_node=False, safe_serialization=False)
                    rprint(f"Sampler state for dataloader saved in {output_sampler_file}")

    initial_global_step = None
    def load_model_hook(models, input_dir):
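        """Accelerate load-state hook: restores the global step, EMA weights, and the weighted sampler state from input_dir."""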
        nonlocal initial_global_step, model, train_ds
        config_path = os.path.join(input_dir, "config.yaml")
        ckpt_config = OmegaConf.load(config_path)
        initial_global_step = ckpt_config.state.ckpt_step
        model.global_step = initial_global_step
        try:
            if hasattr(config.state, "num_evals"):
                model.num_evals = config.state.num_evals
        except Exception as e:
            rprint(f"Error loading model: {e}")
        rprint(f"Loaded global step {initial_global_step}")

        state_dict = None
        if getattr(config.checkpointing, "load_from_old_attention_format", False):
            state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
            state_dict = convert_state_dict_keys(state_dict)

        if getattr(model, "ema", None) is not None:
            if get_ema_path(input_dir).exists():
                rprint(f"Loading EMA from {get_ema_path(input_dir)}")
                model.ema.load_state_dict(torch.load(get_ema_path(input_dir), map_location='cpu'))
            else:
                rprint(f"No EMA found, initializing EMA with state_dict")
                if state_dict is None:
                    state_dict = load_file(os.path.join(input_dir, "model.safetensors"))

                # We likely don't need the unwrap, but just to be safe
                accelerator.unwrap_model(models[0]).load_state_dict(state_dict)
                from models.ema import EMAModel
                model.ema = EMAModel(accelerator.unwrap_model(models[0]).parameters(), decay=config.trainer.ema)

        if config.data.use_weighted_tensordict_sampler and not is_torch_xla_available(): # and not config.eval.test_eval_speed:
            input_sampler_file = Path(input_dir).joinpath(f"{SAMPLER_NAME}_train.bin")
            if train_ds is not None and input_sampler_file.exists():
                train_ds.sampler.load_state_dict(torch.load(input_sampler_file))
            rprint("All dataloader sampler states loaded successfully")

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
    model.init_dataloader(train_ds, valid_ds)
    model.set_accelerator(accelerator, ckpt_path)
    model.set_callbacks()

    if getattr(config.checkpointing, "load_from_text_model", None) is not None:
        rprint(f"Loading from text model")
        model.custom_load_checkpoint()

    if getattr(config.checkpointing, "load_from_lightning_ckpt", None) is not None:
        ckpt = torch.load(config.checkpointing.load_from_lightning_ckpt)
        initial_global_step = ckpt["global_step"]
        state_dict_ = {k.removeprefix("backbone."): v for k, v in ckpt["state_dict"].items() if "backbone" in k}
        state_dict_ = {k.replace(".attn_", ".attention.attn_"): v for k, v in state_dict_.items()}
        accelerator.unwrap_model(model.backbone).load_state_dict(state_dict_)

        if config.trainer.ema > 0:
            model.ema.load_state_dict(ckpt["ema"])

        rprint(f"Loaded lightning ckpt: {config.checkpointing.load_from_lightning_ckpt}")

    if initial_global_step is not None:
        # The load hooks run before accelerate does its own loading, which would overwrite model.global_step if we set it there
        model.global_step = initial_global_step
        rprint(f"Set global step to {initial_global_step}")

    contexts = []
    if config.trainer.nvtx_profile:
        contexts.append(torch.autograd.profiler.emit_nvtx(record_shapes=True))

    if config.trainer.profile_memory:
        contexts.append(profile_memory())

    using_torch_elastic = torch.distributed.is_torchelastic_launched()
    if using_torch_elastic:
        rprint(f"Torchelastic launched: {torch.distributed.is_torchelastic_launched()}")
        contexts.append(ErrorHandler())

    with ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)

        rprint(f"output_dir: {config.output_dir}")
        model.to(accelerator.device)
        if config.mode == 'train':
            model.train()
        elif config.mode == 'eval':
            model.validate(None)
        elif config.mode == 'zero-shot-eval':
            model.zero_shot_eval()
        else:
            raise ValueError(f"Invalid mode: {config.mode}")

    accelerator.end_training()


def get_run_cmd(config):
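    """Reconstruct the command used to launch this run (with HYDRA_RUN_DIR) so it can be logged and reused for restarts."""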
    orig_argv = deepcopy(sys.argv)

    prepend_argv = []
    if "HYDRA_RUN_DIR" in os.environ:
        prepend_argv.append(f"HYDRA_RUN_DIR='{os.environ['HYDRA_RUN_DIR']}'")
    else:
        prepend_argv.append(f"HYDRA_RUN_DIR='{str(Path(config.output_dir).resolve())}'")

    if orig_argv[1].startswith("experiments=["):
        orig_argv[1] = orig_argv[1].removeprefix("experiments=[").removesuffix("]")
        orig_argv[1] = f"experiments=\'[{orig_argv[1]}]\'"

    if os.environ.get("CUSTOM_SBATCH_LAUNCHER", "0") == "1":
        sbatch_script_path = 'scripts/slurm.sh'
        orig_argv.pop(0)
        orig_argv = ["sbatch", f"--nodes={os.environ.get('SLURM_NNODES', '1')}", f"--gpus-per-node={os.environ.get('SLURM_GPUS_PER_NODE', '1')}", f"--partition={os.environ.get('SLURM_JOB_PARTITION', 'all')}", sbatch_script_path] + orig_argv
    else:
        prepend_argv.append("accelerate launch")

    full_run_cmd = " ".join(prepend_argv + orig_argv)
    rprint(f"Full run cmd: {full_run_cmd}")
    return full_run_cmd

@hydra.main(version_base=None, config_path=CONFIG_PATH, config_name="config")
@try_except()
def main(config):
    """Main entry point for the trainer."""
    if is_sweep():
        print(f"Checking if we need to requeue for job id {os.environ['SLURM_JOB_ID']}")
        from unidisc.utils.slurm_requeue import check_requeue
        check_requeue()
        print(f"Done checking if we need to requeue for job id {os.environ['SLURM_JOB_ID']}.")

    import torch  # Imported lazily since importing torch at module level causes pickling issues.
    if is_torch_xla_available():
        builtins.HAS_XLA_SPAWNED = True
        os.environ['PJRT_DEVICE'] = 'TPU'

        if config.trainer.precision == "bf16":
            os.environ['XLA_USE_BF16'] = '1'

        if config.devices == 1 and config.trainer.xla_spmd is False and config.trainer.fsdp is False:
            os.environ['TPU_PROCESS_BOUNDS'] = '1,1,1'
            os.environ['TPU_VISIBLE_CHIPS'] = '0'
            gprint(f"Setting TPU_PROCESS_BOUNDS: {os.environ['TPU_PROCESS_BOUNDS']}")
            gprint(f"Setting TPU_VISIBLE_CHIPS: {os.environ['TPU_VISIBLE_CHIPS']}")

        if config.trainer.tpu_eager:
            os.environ['XLA_USE_EAGER_DEBUG_MODE'] = '1'

        if config.trainer.tpu_compile_debug:
            os.environ['PT_XLA_DEBUG'] = '1'
            os.environ['PT_XLA_DEBUG_LEVEL'] = '2'
            os.environ['XLA_IR_DEBUG'] = '1'
            os.environ['XLA_HLO_DEBUG'] = '1'
            os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
            os.environ['TF_CPP_VMODULE'] = 'xla_graph_executor=5,pjrt_computation_client=3'

        # We intentionally set these afterwards to avoid import side effects
        spmd_mesh, spmd_mesh_shape, axis_names, num_nodes = None, None, None, None
        if config.trainer.xla_spmd:
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.spmd as xs
            import torch_xla.runtime as xr
            from accelerate import PartialState
            from torch_xla._internal import tpu
            auto_spmd = getattr(config.trainer, "auto_spmd", False)

            xr.use_spmd(auto=auto_spmd) # Auto causes a crash
            force_global_devices = getattr(config.trainer, "force_global_devices", None)
            force_local_devices = getattr(config.trainer, "force_local_devices", None)
            assert (force_global_devices is None) == (force_local_devices is None), "Must set both or neither"

            if force_global_devices is not None:
                num_global_devices = force_global_devices
                num_local_devices = force_local_devices
                gprint(f"Using force global devices: num_global_devices={num_global_devices}, num_local_devices={num_local_devices}")
            else:
                num_global_devices = xr.global_runtime_device_count()
                num_local_devices = tpu.num_available_devices()
                assert num_global_devices == tpu.num_expected_global_devices()
                assert tpu.num_available_devices() == tpu.num_available_chips() == tpu.num_local_processes()

            num_nodes = num_global_devices // num_local_devices
            spmd_mesh_shape = getattr(config.trainer, "spmd_mesh", None)
            if spmd_mesh_shape is None:
                spmd_mesh_shape = (num_nodes, num_local_devices, 1)

            if getattr(config.trainer, "force_disable_replicas", False):
                spmd_mesh_shape = (1, num_global_devices, 1)
                rprint(f"Forcing disable replicas: {spmd_mesh_shape}")

            if auto_spmd is False:
                if getattr(config.trainer, "spmd_multislice", None) is not None:
                    from torch_xla.distributed.spmd import HybridMesh
                    ici_mesh_shape = spmd_mesh_shape
                    dcn_mesh_shape = (config.trainer.spmd_multislice, 1, 1)
                    spmd_mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('data','fsdp','tensor'))
                    rprint(f"Using multislice: {config.trainer.spmd_multislice}: {ici_mesh_shape} {dcn_mesh_shape}, {spmd_mesh.shape()}")
                else:
                    spmd_mesh = xs.Mesh(np.array(range(num_global_devices)), spmd_mesh_shape, ('dcn', 'fsdp', 'model'))
                xs.set_global_mesh(spmd_mesh)

            config.devices = 1
            config.nodes = 1

            with read_write(config):
                with open_dict(config):
                    config.state = OmegaConf.create(dict(spmd_mesh=spmd_mesh_shape))
                    config.state.axis_names = axis_names
                    config.state.num_nodes = num_nodes
                    config.state.num_global_devices = num_global_devices
                    config.state.num_local_devices = num_local_devices
                    config.state.worker_ips = tpu.get_worker_ips()
                    if os.environ.get("TPU_NAME") is not None:
                        config.state.tpu_name = os.environ.get("TPU_NAME")

        if config.trainer.tpu_eager:
            import torch_xla
            torch_xla.experimental.eager_mode(True)

        if config.trainer.tpu_profile:
            if config.trainer.tpu_profile_markers:
                os.environ['XLA_IR_DEBUG'] = '1'
                os.environ['XLA_HLO_DEBUG'] = '1'
            import torch_xla.debug.profiler as xp
            server = xp.start_server(9012)

        if config.trainer.tpu_cache:
            import torch_xla.runtime as xr
            readonly = not is_main_process()
            rprint(f"Initializing TPU cache with readonly={readonly}")
            xr.initialize_cache(str((Path.home() / '.cache' / 'unidisc' / f"tpu_{get_rank()}_{get_hostname().replace('-', '_')}").resolve()), readonly=readonly)

        if config.trainer.enable_jax_smi:
            initialise_tracking()
            rprint("Initializing jax-smi")

    from unidisc.utils.logging_utils import set_logger
    set_logger(f"{__name__} {get_slurm_log_prefix()}", Path(f"{get_slurm_filename_info()}_{get_hostname().replace('-', '_')}.out"))

    if is_torch_xla_available():
        import torch_xla.core.xla_model as xm
        import torch_xla.runtime as xr
        from accelerate import PartialState
        from torch_xla._internal import tpu
        gprint(
                f"Computed Rank: {get_rank()}, "
                f"Is Main Process: {is_main_process()}, "
                f"Is Local Main Process: {is_local_main_process()}, "
                f"XLA world size: {xr.world_size()}, "
                f"XLA Model Ordinal: {xm.get_ordinal()}, "
                f"XLA Global Ordinal: {xr.global_ordinal()}, "
                f"XLA Supported Devices: {xm.get_xla_supported_devices()}, "
                f"Accelerate Local Process Index: {PartialState().local_process_index}, "
                f"Task ID: {tpu.task_id()}, "
                f"Worker ID: {tpu.worker_id()} "
                f"global device count: {xr.global_runtime_device_count()}, "
                f"local process count: {xr.local_process_count()}, "
                f"local device count: {xr.local_device_count()}, "
                f"addressable device count: {xr.addressable_device_count()}, "
                f"num_expected_global_devices: {tpu.num_expected_global_devices()}, "
                f"num_available_devices: {tpu.num_available_devices()}, "
                f"num_available_chips: {tpu.num_available_chips()}, "
                f"num_local_processes: {tpu.num_local_processes()}, "
                f"process_bounds_size: {tpu.process_bounds_size()}, "
                f"get_worker_ips: {tpu.get_worker_ips()}, "
                f"Computed Num Nodes: {num_nodes}, "
                f"Specified Mesh: {spmd_mesh_shape}, "
                f"Specified Mesh Axes: {axis_names}"
            )

        gprint(f"LIBTPU_INIT_ARGS: {os.environ.get('LIBTPU_INIT_ARGS', 'None')}")
        gprint(f"XLA_FLAGS: {os.environ.get('XLA_FLAGS', 'None')}")
    
    if getattr(config.trainer, "disable_ddp_optimizer", False):
        torch._dynamo.config.optimize_ddp = False

    if config.seed is not None:
        if config.mode == 'eval':
            config.seed = config.seed + 1000 * int(get_rank())
        else:
            config.seed = config.seed + int(get_rank())
        np.random.seed(config.seed)
        random.seed(config.seed)
        torch.manual_seed(config.seed)
        if is_torch_cuda_available():
            # TODO: Is seed all desired? Does it set the same one on all GPUs even in multi-process?
            torch.cuda.manual_seed_all(config.seed)

        if is_torch_xla_available():
            import torch_xla.core.xla_model as xm
            xm.set_rng_state(config.seed)
        gprint(f"Set seed: {config.seed}")
    else:
        rprint("No seed provided")

    _print_config(config, resolve=True, save_cfg=True)

    with open(f"env_vars_{get_slurm_filename_info()}_{get_hostname().replace('-', '_')}.txt", "w") as f:
        for key, value in os.environ.items():
            f.write(f"{key}={value}\n")

    tokenizer = dataloader.get_tokenizer(config)

    if "tokens" in config.data.train and (config.loader.num_workers > 0 or getattr(config.data, "force_mp_spawn", False)):
        from torch import multiprocessing as mp
        try:
            rprint(f"Start already method set to: {mp.get_start_method()}")
        except:
            mp.set_start_method("spawn")
            rprint(f"Start method set to: {mp.get_start_method()}")

    rprint(f"Mode: {config.mode}")
    if config.mode == "sample_eval":
        generate_samples(config, tokenizer)
    else:
        try:
            run(config, tokenizer)
        except Exception as e:
            rprint(f"Traceback: {traceback.format_exc()}")
            rprint(f"Exception: {e}")
    
            import time
            timestamp = time.time_ns()
            error_filepath = f"exception_{timestamp}_{process_file_prefix()}.out"
            with open(error_filepath, "w") as file:
                file.write(traceback.format_exc())
            rprint(f"See error file {Path(error_filepath).resolve()} for traceback")
            
            if is_torch_xla_available():
                exit(1)

            if ("SLURM_JOB_ID" not in os.environ) and ("RESTART_FAULT_TOLERANT" not in os.environ) and not is_torch_xla_available():
                gprint(f"Entering debugger")
                breakpoint(traceback=e.__traceback__)
            else:
                rprint(f"Not breaking, SLURM_JOB_ID: {os.environ.get('SLURM_JOB_ID')}, RESTART_FAULT_TOLERANT: {os.environ.get('RESTART_FAULT_TOLERANT')}")

            if "RESTART_FAULT_TOLERANT" in os.environ:
                sigterm_handler = signal.getsignal(signal.SIGTERM)
                if callable(sigterm_handler):
                    rprint(f"Calling SIGTERM handler")
                    sigterm_handler(signal.SIGTERM, None)

                try:
                    if config.trainer.num_nodes > 1 and config.debug is False and is_main_process():
                        wandb.alert(title="Exception!", text=f"{e}, {traceback.format_exc()}")
                except Exception:
                    pass
            raise e
        finally:
            pass

if __name__ == "__main__":
    main()