import math
import time
import warnings
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import Trainer

with warnings.catch_warnings():
    warnings.simplefilter(action="ignore", category=FutureWarning)
    from transformers.deepspeed import deepspeed_init

from transformers.integrations import TensorBoardCallback
from transformers.trainer_pt_utils import (
    IterableDatasetShard,
    find_batch_size,
    nested_concat,
    nested_detach,
    nested_numpify,
)
from transformers.trainer_utils import denumpify_detensorize, has_length, speed_metrics
from transformers.utils import (
    is_apex_available,
    is_datasets_available,
    is_sagemaker_mp_enabled,
    logging,
)

from sdlm.inference.inference_utils import (
    logits_projection,
    predict_conditional_generated,
)
from sdlm.models.utils import is_cdcd_check
from sdlm.pipelines.simplex_ddpm import SimplexDDPMClassifierGuidancePipeline
from sdlm.utils import convert_to_simplex, pad_data, scale, self_condition_preds

if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

GENERATION_RESULTS = "generated"


logger = logging.get_logger(__name__)


class EvalLoopOutput(NamedTuple):
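    """Container for evaluation outputs: logits, (softmax-processed) simplex scores,
    input ids, computed metrics, decoded text results, and the number of samples."""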
    logits: Union[np.ndarray, Tuple[np.ndarray]]
    simplex: Union[np.ndarray, Tuple[np.ndarray]]
    input_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
    metrics: Optional[Dict[str, float]]
    results: Optional[Dict[str, List[str]]]
    num_samples: Optional[int]


class DiffusionTrainer(Trainer):
    def __init__(
        self,
        noise_scheduler,
        inference_noise_schedulers,
        diffusion_args,
        data_args,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.original_data_collator = self.data_collator
        self.noise_scheduler = noise_scheduler
        self.diffusion_args = diffusion_args
        self.data_args = data_args
        self.vocab_size = self.model.config.vocab_size
        self.inference_noise_schedulers = inference_noise_schedulers
        self.inference_timesteps = diffusion_args.num_inference_diffusion_steps
        self.tb_writer = self.get_tb_writer()
        self.eos_token_id = self.tokenizer.eos_token_id
        self.classifier_free_guidance = (
            diffusion_args.guidance_scale > 1.0
            and data_args.conditional_generation is not None
        )
        self.counter = 0
        # TODO: control seed.
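        # Dedicated RNG for the 50/50 self-conditioning coin flips, kept separate from the global seed.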
        self.self_cond_generator = np.random.default_rng(42)

    def annotated_split(self, split):
        return (
            f"{split}_top_p_{self.diffusion_args.top_p}"
            f"_temperature_{self.diffusion_args.temperature}"
            f"_seed_{self.args.seed}"
            f"_guidance_scale_{self.diffusion_args.guidance_scale}"
        )

    def save_metrics(self, split, metrics, combined=True):
        super().save_metrics(self.annotated_split(split), metrics, combined)

    def log_metrics(self, split, metrics):
        super().log_metrics(self.annotated_split(split), metrics)

    def get_tb_writer(self):
        for cb in self.callback_handler.callbacks:
            if isinstance(cb, TensorBoardCallback):
                return cb
        return None

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.
        Subclass and override to inject custom behavior.
        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.
                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
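
        In this subclass, the step additionally builds the noisy simplex from the input ids,
        samples (and, for CDCD models, CDF-warps) the diffusion timesteps, and optionally runs
        a detached self-conditioning pass before computing the diffusion loss.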
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Truncate the length if needed.
        if self.data_args.truncation_length > 0:
            inputs["input_ids"] = inputs["input_ids"][
                :, : -self.data_args.truncation_length
            ]
            inputs["span_mask"] = inputs["span_mask"][
                :, : -self.data_args.truncation_length
            ]

        # Creates the noisy simplex and timesteps.
        simplex = convert_to_simplex(
            inputs["input_ids"], self.diffusion_args.simplex_value, self.vocab_size
        )
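        # convert_to_simplex maps token ids to a scaled (near-)one-hot representation over the
        # vocabulary with magnitude simplex_value; see sdlm.utils.convert_to_simplex for the exact mapping.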
        noise = self.diffusion_args.simplex_value * torch.randn(
            simplex.shape, device=simplex.device, dtype=simplex.dtype
        )
        bsz = simplex.shape[0]
        # Sample a random timestep per sequence and broadcast it across tokens;
        # sampling one timestep for the whole sequence better matches how generation proceeds.
        if True:  # np.random.rand(1) > 0.5:
            timesteps = torch.randint(
                0,
                len(self.noise_scheduler),
                (bsz, inputs["input_ids"].shape[1])
                if False  # is_tokenwise_cdcd_check(self.model)
                else (bsz,),
                device=simplex.device,
                dtype=torch.int64,
            )
            timesteps = timesteps[:, None].expand(-1, inputs["input_ids"].shape[1])
        else:
            timesteps = torch.randint(
                0,
                len(self.noise_scheduler),
                (bsz, inputs["input_ids"].shape[1])
                if True  # is_tokenwise_cdcd_check(self.model)
                else (bsz,),
                device=simplex.device,
                dtype=torch.int64,
            )
        # expand out timesteps to match tokenwise setup
        # if True:  # not is_tokenwise_cdcd_check(self.model):
        #     timesteps = timesteps[:, None].expand(-1, inputs["input_ids"].shape[1])

        # save original timesteps for warping
        original_timesteps = timesteps
        # Warp timesteps according to the learned CDF (CDCD) and re-scale them to the valid range;
        # t_max is len(noise_scheduler) - 1 because timesteps must lie in [0, len(noise_scheduler)).
        if is_cdcd_check(self.model):
            input_ids = inputs["input_ids"]
            span_mask = inputs["span_mask"]
            token_input = torch.where(
                (input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids
            )
            timesteps = self.model.warp_timesteps(
                timesteps,
                token_input=token_input,
                span_mask=span_mask,
                t_max=len(self.noise_scheduler) - 1,
            )
        # Adds noise to each simplex representation (Forward diffusion process).
        noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps)
        # the warper model will scale the timesteps to the correct range.
        timesteps = scale(timesteps, len(self.noise_scheduler))
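        # scale() normalizes the integer timesteps by len(noise_scheduler); the inverse
        # (multiplying back and adding 1) is applied in the self-conditioning branch below.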
        # original_timesteps_scaled = scale(original_timesteps, len(self.noise_scheduler))
        # inputs.update(
        #     {"original_timesteps": scale(original_timesteps, len(self.noise_scheduler))}
        # )

        inputs.update(
            {
                "timesteps": timesteps,
                "simplex": noisy_simplex,
            }
        )
        # inputs.update({"max_timestep": len(self.noise_scheduler)})
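        # Self-conditioning: with probability 0.5, run a detached forward pass at a slightly
        # noisier timestep and feed its projected prediction back to the model as previous_pred.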
        if self.diffusion_args.self_condition is not None:
            previous_pred = None
            # previous_hidden = None
            if self.self_cond_generator.random(1) > 0.5:
                next_timestep = inputs.pop("timesteps")
                next_simplex = inputs.pop("simplex")
                timesteps = torch.clamp(
                    (next_timestep * len(self.noise_scheduler)) + 1,
                    max=len(self.noise_scheduler) - 1,
                )
                if is_cdcd_check(self.model):
                    input_ids = inputs["input_ids"]
                    span_mask = inputs["span_mask"]
                    token_input = torch.where(
                        (input_ids * span_mask) > 1,
                        self.tokenizer.pad_token_id,
                        input_ids,
                    )
                    timesteps = self.model.warp_timesteps(
                        timesteps,
                        token_input=token_input,
                        span_mask=span_mask,
                        t_max=len(self.noise_scheduler) - 1,
                    )
                noisy_simplex = self.noise_scheduler.add_noise(
                    simplex, noise, timesteps
                )
                timesteps = scale(timesteps, len(self.noise_scheduler))
                inputs.update(
                    {
                        "timesteps": timesteps,
                        "simplex": noisy_simplex,
                    }
                )
                # we don't backprop through this.
                with torch.no_grad():
                    outputs = model(**inputs, previous_pred=previous_pred)
                logits_projection_fct = lambda x: logits_projection(  # noqa: E731
                    x,
                    self.diffusion_args.sampling_type,
                    self.diffusion_args.top_p,
                    self.diffusion_args.simplex_value,
                    self.diffusion_args.temperature,
                )
                previous_pred = self_condition_preds(
                    self.diffusion_args.self_condition,
                    outputs.logits,
                    logits_projection_fct,
                ).detach()
                # As with the rest of self-conditioning, we don't backprop through this.
                # previous_hidden = outputs.hidden_states.detach()
                # Restore the original timesteps/simplex for the actual training pass.
                inputs.update(
                    {
                        "timesteps": next_timestep,
                        "simplex": next_simplex,
                    }
                )
            inputs.update({"previous_pred": previous_pred})
            # inputs.update({"previous_hidden": previous_hidden})
        else:
            inputs.update({"previous_pred": None})
            # inputs.update({"previous_hidden": None})
            # previous_hidden = None
        # NOTE: we do this after computation of self-conditioning to not affect that one.
        # inputs.update(
        #     {"classifier_free_guidance_in_train": self.classifier_free_guidance}
        # )
        # Re-warp the original timesteps for the main training pass.
        if is_cdcd_check(self.model):
            # Replace tokens under the span mask with the pad token before warping.
            input_ids = inputs["input_ids"]
            span_mask = inputs["span_mask"]
            token_input = torch.where(
                (input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids
            )
            timesteps = self.model.warp_timesteps(
                original_timesteps,
                t_max=len(self.noise_scheduler) - 1,
                token_input=token_input,
                span_mask=span_mask,
            )
            noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps)
            timesteps = scale(timesteps, len(self.noise_scheduler))
            inputs.update(
                {
                    "timesteps": timesteps,
                    "simplex": noisy_simplex,
                }
            )
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        # HACK: transformer update
        # if self.do_grad_scaling:
        #     self.scaler.scale(loss).backward()
        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)
        return loss.detach() / self.args.gradient_accumulation_steps

    def light_prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
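        # Mirrors the forward/noising logic of training_step under no_grad so that the
        # reported eval loss is comparable to the training loss.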
        with torch.no_grad():
            inputs = self._prepare_inputs(inputs)
            # Truncate the length if needed.
            if self.data_args.truncation_length > 0:
                inputs["input_ids"] = inputs["input_ids"][
                    :, : -self.data_args.truncation_length
                ]
                inputs["span_mask"] = inputs["span_mask"][
                    :, : -self.data_args.truncation_length
                ]
            # Creates the noisy simplex and timesteps.
            simplex = convert_to_simplex(
                inputs["input_ids"], self.diffusion_args.simplex_value, self.vocab_size
            )
            noise = self.diffusion_args.simplex_value * torch.randn(
                simplex.shape, device=simplex.device, dtype=simplex.dtype
            )
            bsz = simplex.shape[0]
            # Sample a random timestep per sequence, using the training timestep range to stay
            # consistent with the training process (the tokenwise branch below is disabled).
            if True:
                timesteps = torch.randint(
                    0,
                    len(self.noise_scheduler),
                    (bsz, inputs["input_ids"].shape[1])
                    if False  # is_tokenwise_cdcd_check(self.model)
                    else (bsz,),
                    device=simplex.device,
                    dtype=torch.int64,
                )
                timesteps = timesteps[:, None].expand(-1, inputs["input_ids"].shape[1])
            else:
                timesteps = torch.randint(
                    0,
                    len(self.noise_scheduler),
                    (bsz, inputs["input_ids"].shape[1])
                    if True  # is_tokenwise_cdcd_check(self.model)
                    else (bsz,),
                    device=simplex.device,
                    dtype=torch.int64,
                )
            # original_timesteps = timesteps

            # If CDCD is used, warp the timesteps through the learned CDF,
            # making sure they are scaled to the correct range.
            if is_cdcd_check(self.model):
                input_ids = inputs["input_ids"]
                span_mask = inputs["span_mask"]
                token_input = torch.where(
                    (input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids
                )
                timesteps = self.model.warp_timesteps(
                    timesteps,
                    t_max=len(self.noise_scheduler) - 1,
                    token_input=token_input,
                    span_mask=span_mask,
                )

            # Adds noise to each simplex representation (Forward diffusion process).
            noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps)

            timesteps = scale(timesteps, len(self.noise_scheduler))
            # original_timesteps_scaled = scale(
            #     original_timesteps, len(self.noise_scheduler)
            # )
            # inputs.update({"original_timesteps": original_timesteps_scaled})

            inputs.update(
                {
                    "timesteps": timesteps,
                    "simplex": noisy_simplex,
                }
            )
            # inputs.update({"max_timestep": len(self.noise_scheduler)})
            if self.diffusion_args.self_condition is not None:
                previous_pred = None
                # last_hidden_state = None
                if np.random.rand(1) > 0.5:
                    outputs = model(**inputs, previous_pred=previous_pred)
                    logits_projection_fct = lambda x: logits_projection(  # noqa: E731
                        x,
                        self.diffusion_args.sampling_type,
                        self.diffusion_args.top_p,
                        self.diffusion_args.simplex_value,
                        self.diffusion_args.temperature,
                    )
                    previous_pred = self_condition_preds(
                        self.diffusion_args.self_condition,
                        outputs.logits,
                        logits_projection_fct,
                    )
                    # last_hidden_state = outputs.hidden_states
                inputs.update(
                    {
                        "previous_pred": previous_pred,
                        # "previous_hidden": last_hidden_state,
                    }
                )
            # NOTE: we do this after computation of self-conditioning to not affect that one.
            # inputs.update(
            #     {"classifier_free_guidance_in_train": self.classifier_free_guidance}
            # )
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)

            if self.args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training

            # No division by gradient accumulation steps for eval: we want the per-sample average loss.
            return loss.detach()

    # TODO: argument for doing one step.
    def prediction_step(
        self,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        model: nn.Module,
        pipeline: SimplexDDPMClassifierGuidancePipeline,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        inputs = self._prepare_inputs(inputs)
        # full inference.
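        # The pipeline yields intermediate states for each denoising step; we iterate it to
        # completion and keep only the final state's simplex and logits.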
        with torch.no_grad():
            with self.compute_loss_context_manager():
                for i, x in enumerate(
                    pipeline(
                        seq_length=self.data_args.max_seq_length
                        - self.data_args.truncation_length,
                        batch=inputs,
                        guidance_scale=self.diffusion_args.guidance_scale,
                        generator=torch.Generator(device=self.args.device).manual_seed(
                            self.args.seed
                        )
                        if self.diffusion_args.generate_with_seed
                        else None,
                        is_generator=False,
                        use_gumbel_softmax=self.diffusion_args.use_gumbel_softmax,
                        do_hard_sample=self.diffusion_args.do_hard_sample,
                        softmax_temperature=self.diffusion_args.softmax_temperature,
                        num_guidance_steps=self.diffusion_args.num_guidance_steps,
                    )
                ):
                    outputs = x
        logits = nested_detach(outputs.logits)
        simplex = nested_detach(outputs.simplex)

        return (simplex, logits)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        noise_scheduler=None,
        light_eval_dataloader=None,
        do_light_eval=False,
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
        Works both with or without labels.
        """
        args = self.args
        is_conditional_generation = self.data_args.conditional_generation is not None
        save_prefixes = is_conditional_generation

        prediction_loss_only = (
            prediction_loss_only
            if prediction_loss_only is not None
            else args.prediction_loss_only
        )
        # if eval is called w/o train handle model prep here
        if self.is_deepspeed_enabled and self.model_wrapped is self.model:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        pipeline = SimplexDDPMClassifierGuidancePipeline(
            model=model,
            scheduler=noise_scheduler,
            simplex_value=self.diffusion_args.simplex_value,
            top_p=self.diffusion_args.top_p,
            sampling_type=self.diffusion_args.sampling_type,
            is_conditional_generation=is_conditional_generation,
            tokenizer=self.tokenizer,
            classifier_free_uncond_input=self.diffusion_args.classifier_free_uncond_input,
            temperature=self.diffusion_args.temperature,
            guidance_softmax_combination=self.diffusion_args.guidance_softmax_combination,
            classifier_model_name_or_path=self.diffusion_args.classifier_model_name_or_path,
        )

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        # Initialize containers
        # logits/simplex/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        logits_host = None
        simplex_host = None
        inputs_host = None
        masks_host = None
        prefixes_host = None

        # logits/simplex/labels on CPU (final containers)
        all_losses = None
        all_logits = None
        all_simplex = None
        all_inputs = None
        all_masks = None
        all_prefixes = None
        observed_num_examples = 0

        # light evaluation loop.
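        # Two passes: this "light" pass mimics training to compute a masked loss; the full
        # generation pass over the main dataloader follows below.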
        if light_eval_dataloader is not None and do_light_eval:
            for step, inputs in enumerate(light_eval_dataloader):
                # Truncate the length if needed.
                if self.data_args.truncation_length > 0:
                    inputs["input_ids"] = inputs["input_ids"][
                        :, : -self.data_args.truncation_length
                    ]
                    inputs["span_mask"] = inputs["span_mask"][
                        :, : -self.data_args.truncation_length
                    ]
                    max_seq_length = (
                        self.data_args.max_seq_length - self.data_args.truncation_length
                    )
                    assert self.data_args.eval_context_size < max_seq_length
                # predict loss mimicking training.
                loss = self.light_prediction_step(model, inputs)

                if loss is not None:
                    losses = self._nested_gather(loss.repeat(batch_size))
                    losses_host = (
                        losses
                        if losses_host is None
                        else torch.cat((losses_host, losses), dim=0)
                    )
                if (
                    args.eval_accumulation_steps is not None
                    and (step + 1) % args.eval_accumulation_steps == 0
                ):
                    if losses_host is not None:
                        losses = nested_numpify(losses_host)
                        all_losses = (
                            losses
                            if all_losses is None
                            else np.concatenate((all_losses, losses), axis=0)
                        )
                    losses_host = None

        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            has_mask = "span_mask" in inputs

            # Truncate the length if needed.
            if self.data_args.truncation_length > 0:
                inputs["input_ids"] = inputs["input_ids"][
                    :, : -self.data_args.truncation_length
                ]
                inputs["span_mask"] = inputs["span_mask"][
                    :, : -self.data_args.truncation_length
                ]
                max_seq_length = (
                    self.data_args.max_seq_length - self.data_args.truncation_length
                )
                assert self.data_args.eval_context_size < max_seq_length

            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            simplex, logits = self.prediction_step(inputs, model, pipeline=pipeline)
            inputs_decode = self._prepare_input(inputs["input_ids"])
            masks = self._prepare_input(inputs["span_mask"]) if has_mask else None
            if save_prefixes:
                prefixes = (
                    pad_data(
                        [input[~mask] for input, mask in zip(inputs_decode, masks)],
                        self.tokenizer,
                    )
                    if has_mask
                    else None
                )
                prefixes = self._prepare_input(prefixes)
            else:
                prefixes = None
            # Update containers on host
            if prefixes is not None:
                prefixes = self.accelerator.pad_across_processes(
                    prefixes, dim=1, pad_index=self.eos_token_id
                )
                prefixes = self._nested_gather(prefixes)
                prefixes_host = (
                    prefixes
                    if prefixes_host is None
                    else nested_concat(
                        prefixes_host, prefixes, padding_index=self.eos_token_id
                    )
                )
            if inputs_decode is not None:
                inputs_decode = self.accelerator.pad_across_processes(
                    inputs_decode, dim=1, pad_index=self.eos_token_id
                )
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(
                        inputs_host, inputs_decode, padding_index=self.eos_token_id
                    )
                )
            # Note that this block should be before masks block, since we need masks here.
            if simplex is not None:
                # In case of having a mask softmax is applied over the simplex non-masked values.
                if has_mask:
                    mask_value = torch.finfo(simplex.dtype).min
                    mask_value = torch.tensor(
                        mask_value, dtype=simplex.dtype, device=simplex.device
                    )
                    simplex = torch.where(masks[:, :, None], simplex, mask_value)
                simplex = F.softmax(simplex, dim=-1)
                if self.preprocess_logits_for_metrics is not None:
                    simplex = self.preprocess_logits_for_metrics(simplex)
                simplex = self.accelerator.pad_across_processes(
                    simplex, dim=1, pad_index=self.eos_token_id
                )
                simplex = self._nested_gather(simplex)
                # TODO: note that this is no longer a raw simplex, but its softmax-processed version.
                simplex_host = (
                    simplex
                    if simplex_host is None
                    else nested_concat(
                        simplex_host, simplex, padding_index=self.eos_token_id
                    )
                )
            if masks is not None:
                masks = self.accelerator.pad_across_processes(masks, dim=1, pad_index=0)
                masks = self._nested_gather(masks)
                # We pad masks with False tokens.
                masks_host = (
                    masks
                    if masks_host is None
                    else nested_concat(masks_host, masks, padding_index=0)
                )
            if logits is not None:
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits)
                logits = self.accelerator.pad_across_processes(
                    logits, dim=1, pad_index=self.eos_token_id
                )
                logits = self._nested_gather(logits)
                logits_host = (
                    logits
                    if logits_host is None
                    else nested_concat(
                        logits_host, logits, padding_index=self.eos_token_id
                    )
                )

            self.control = self.callback_handler.on_prediction_step(
                args, self.state, self.control
            )

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if (
                args.eval_accumulation_steps is not None
                and (step + 1) % args.eval_accumulation_steps == 0
            ):
                if logits_host is not None:
                    logits = nested_numpify(logits_host)
                    all_logits = (
                        logits
                        if all_logits is None
                        else nested_concat(
                            all_logits, logits, padding_index=self.eos_token_id
                        )
                    )
                if simplex_host is not None:
                    simplex = nested_numpify(simplex_host)
                    all_simplex = (
                        simplex
                        if all_simplex is None
                        else nested_concat(
                            all_simplex, simplex, padding_index=self.eos_token_id
                        )
                    )
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(
                            all_inputs, inputs_decode, padding_index=self.eos_token_id
                        )
                    )
                if masks_host is not None:
                    masks = nested_numpify(masks_host)
                    all_masks = (
                        masks
                        if all_masks is None
                        else nested_concat(all_masks, masks, padding_index=0)
                    )
                if prefixes_host is not None:
                    prefixes = nested_numpify(prefixes_host)
                    all_prefixes = (
                        prefixes
                        if all_prefixes is None
                        else nested_concat(
                            all_prefixes, prefixes, padding_index=self.eos_token_id
                        )
                    )

                # Set back to None to begin a new accumulation
                logits_host, simplex_host, inputs_host, masks_host, prefixes_host = (
                    None,
                    None,
                    None,
                    None,
                    None,
                )

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            all_losses = nested_numpify(losses_host)
        if logits_host is not None:
            all_logits = nested_numpify(logits_host)
        if simplex_host is not None:
            all_simplex = nested_numpify(simplex_host)
        if inputs_host is not None:
            all_inputs = nested_numpify(inputs_host)
        if masks_host is not None:
            all_masks = nested_numpify(masks_host)
        if prefixes_host is not None:
            all_prefixes = nested_numpify(prefixes_host)

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif (
            isinstance(eval_dataset, IterableDatasetShard)
            and getattr(eval_dataset, "num_examples", 0) > 0
        ):
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Generates the texts.
        results = {}
        if is_conditional_generation:
            # We predict the masked tokens only. Here, we compute the masked tokens.
            results.update(
                predict_conditional_generated(
                    all_masks,
                    all_inputs,
                    self.tokenizer,
                    all_simplex,
                    "pred_texts_from_simplex",
                    self.data_args.skip_special_tokens,
                )
            )
            results.update(
                predict_conditional_generated(
                    all_masks,
                    all_inputs,
                    self.tokenizer,
                    all_logits,
                    "pred_texts_from_logits",
                    self.data_args.skip_special_tokens,
                )
            )
        else:
            results.update(
                {
                    "pred_texts_from_simplex": self.tokenizer.batch_decode(
                        all_simplex,
                        skip_special_tokens=self.data_args.skip_special_tokens,
                    )
                }
            )
            results.update(
                {
                    "pred_texts_from_logits": self.tokenizer.batch_decode(
                        all_logits,
                        skip_special_tokens=self.data_args.skip_special_tokens,
                    )
                }
            )
        if is_conditional_generation:
            results.update(
                {
                    "gold_texts_masked": [
                        self.tokenizer.decode(
                            input[mask],
                            skip_special_tokens=self.data_args.skip_special_tokens,
                        )
                        for mask, input in zip(all_masks, all_inputs)
                    ]
                }
            )
            if save_prefixes:
                results.update(
                    {
                        "prefixes": [
                            self.tokenizer.decode(
                                x, skip_special_tokens=True
                            )  # self.data_args.skip_special_tokens)
                            for x in all_prefixes
                        ]
                    }
                )

        # Metrics.
        if self.compute_metrics is not None:
            metrics = self.compute_metrics(results)
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(
            logits=all_logits,
            simplex=all_simplex,
            input_ids=all_inputs,
            metrics=metrics,
            num_samples=num_samples,
            results=results,
        )

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.
        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).
        You can also subclass and override this method to inject custom behavior.
        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)
        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
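
        Example (illustrative sketch; assumes the trainer was constructed with matching
        inference noise schedulers):

            metrics = trainer.evaluate(metric_key_prefix="eval")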
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        light_eval_dataloader = self.get_light_eval_dataloader(eval_dataset)
        start_time = time.time()

        outputs = []
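        # One full evaluation is run per configured number of inference steps, each paired
        # with its corresponding inference noise scheduler.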
        timesteps = self.inference_timesteps
        for timestep, noise_scheduler in zip(
            timesteps, self.inference_noise_schedulers
        ):
            output = self.evaluation_loop(
                eval_dataloader,
                description="Evaluation",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only
                prediction_loss_only=True if self.compute_metrics is None else None,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
                noise_scheduler=noise_scheduler,
                light_eval_dataloader=light_eval_dataloader,
                # We only need the light-eval loss once, since it is the same for all timesteps.
                do_light_eval=timestep == timesteps[0],
            )
            outputs.append(output)
            key_prefix = f"inference_{timestep}_"
            metrics = {key_prefix + k: v for k, v in output.metrics.items()}
            results = {key_prefix + k: v for k, v in output.results.items()}
            # reset output with new metrics / results
            output = EvalLoopOutput(
                logits=output.logits,
                simplex=output.simplex,
                input_ids=output.input_ids,
                metrics=metrics,
                num_samples=output.num_samples,
                results=results,
            )

            total_batch_size = self.args.eval_batch_size * self.args.world_size
            output.metrics.update(
                speed_metrics(
                    metric_key_prefix,
                    start_time,
                    num_samples=output.num_samples,
                    num_steps=math.ceil(output.num_samples / total_batch_size),
                )
            )
            self.log(output.metrics)
            self.control = self.callback_handler.on_evaluate(
                self.args, self.state, self.control, output.metrics
            )
            self._memory_tracker.stop_and_update_metrics(output.metrics)

            # Save the results
            self.save_metrics(
                GENERATION_RESULTS + "_" + key_prefix + metric_key_prefix,
                output.results,
            )
            logger.info("Results are saved now")

        # log outside so we can group generations together
        if self.args.log_generated_texts:
            length = len(outputs[0].logits)
            results = {
                f"{k}_inference_{i}": v
                for o, i in zip(outputs, timesteps)
                for k, v in o.results.items()
            }
            self.log_results_to_tensorboard(self.state, length, results)

        return output.metrics

    def log_results_to_tensorboard(self, state, length, results):
        # The TensorBoard writer may be unavailable (e.g. when running eval only), so guard against it.
        if self.tb_writer is None or self.tb_writer.tb_writer is None:
            return
        for i in range(length):
            total_text = ""
            for k, v in results.items():
                total_text += f"*** {k} ***: {v[i]}" + "  \n"
            self.tb_writer.tb_writer.add_text(
                f"sample_{i}", total_text, state.global_step
            )

    def get_train_dataloader(self) -> DataLoader:
        self.data_collator = self.original_data_collator("train")
        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        self.data_collator = self.original_data_collator("eval")
        return super().get_eval_dataloader(eval_dataset)

    def get_light_eval_dataloader(
        self, eval_dataset: Optional[Dataset] = None
    ) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].
        Used for the light evaluation, which matches masking with training.
        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.original_data_collator("train")
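        # Use the *train* collator here so masking matches training; this is what makes the
        # light-eval loss comparable to the training loss.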

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(
                eval_dataset, description="evaluation"
            )
        else:
            data_collator = self._get_collator_with_removed_columns(
                data_collator, description="evaluation"
            )

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last

        return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

    def create_optimizer(self):
        from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
        from transformers.trainer_pt_utils import get_parameter_names

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is not None:
            return self.optimizer

        decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        decay_parameters = [name for name in decay_parameters if "bias" not in name]
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
            self.args
        )

        # override to apply higher lr to timestep_embed and cdcd cdf
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in opt_model.named_parameters()
                    if (
                        n in decay_parameters
                        and p.requires_grad
                        and not ("timestep_embed" in n or "cdf" in n)
                    )
                ],
                "weight_decay": self.args.weight_decay,
                "lr": optimizer_kwargs["lr"],
            },
            {
                "params": [
                    p
                    for n, p in opt_model.named_parameters()
                    if (
                        n not in decay_parameters
                        and p.requires_grad
                        and not ("timestep_embed" in n or "cdf" in n)
                    )
                ],
                "weight_decay": 0.0,
                "lr": optimizer_kwargs["lr"],
            },
            {
                "params": [
                    p
                    for n, p in opt_model.named_parameters()
                    if (("timestep_embed" in n) and p.requires_grad)
                ],
                "weight_decay": 0.0,
                "lr": self.args.timestep_embed_lr or self.args.learning_rate,
            },
        ]
        # check cdcd
        cdf_params = [
            p
            for n, p in opt_model.named_parameters()
            if (("cdf" in n) and p.requires_grad)
        ]
        if cdf_params:
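            # The CDF (CDCD warping) parameters get a fixed 1e-3 learning rate,
            # independent of the optimizer's base learning rate.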
            optimizer_grouped_parameters.append(
                {
                    "params": cdf_params,
                    "weight_decay": 0.0,
                    "lr": 1e-3,
                }
            )

        optimizer_kwargs.pop("lr")

        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer