import logging
import os
import sys
from typing import Callable

import datasets
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, set_seed
from transformers.trainer_callback import TrainerState
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from .arguments import get_args
from .data.data_collator import SpanInfillingDataCollator
from .data.data_utils import load_data, tokenize_data_new
from .inference.inference_utils import evaluate_generation
from .models import get_torch_dtype, load_model
from .schedulers import TokenWiseSimplexDDPMScheduler
from .trainers.trainer_diffusion import DiffusionTrainer
from .utils import (
    get_last_checkpoint_with_beaker_preemption,
    is_nfs_available,
    is_weka_available,
    resolve_last_checkpoint_vs_resume_from_checkpoint,
    set_hf_home,
    set_pretraining_dataset,
)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.25.0")

# Same kind of guard for `datasets`; the second string is passed along as the
# fix-it hint for the requirement.
require_version(
    "datasets>=2.0.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)

# Module-level logger; its level is set from the training arguments in main().
logger = logging.getLogger(__name__)

# set environment variables
# NOTE: both statements below run at import time, i.e. as a side effect of
# importing this module, before main() is called.
set_hf_home()
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def filter_by_length(min_len: int, pad_token_id: int) -> Callable[[dict], bool]:
    """Build a hashable filter function for the hf dataset library.

    Args:
        min_len: Minimum number of non-padding tokens an example must have.
        pad_token_id: Token id that is ignored when counting tokens.

    Returns:
        A predicate that takes one example (a mapping with an ``"input_ids"``
        entry) and returns ``True`` when the example contains at least
        ``min_len`` ids different from ``pad_token_id``.
    """

    def func(x):
        # Count non-padding ids without materialising an intermediate list.
        non_pad_tokens = sum(1 for i in x["input_ids"] if i != pad_token_id)
        return min_len <= non_pad_tokens

    return func


def get_compute_metrics(data_args, training_args, model_args):
    """Create the metrics function handed to the trainer for evaluation.

    Loads the causal language model and tokenizer named by
    ``model_args.autoregressive_eval_model`` (the model is moved to
    ``training_args.device``) and returns a one-argument callable that
    forwards generation results to ``evaluate_generation``.
    """
    eval_model_name = model_args.autoregressive_eval_model
    dtype = get_torch_dtype(training_args)
    if model_args.use_flash_attention2:
        attention_backend = "flash_attention_2"
    else:
        attention_backend = "eager"

    # Causal language model.
    causal_model = AutoModelForCausalLM.from_pretrained(
        eval_model_name,
        torch_dtype=dtype,
        attn_implementation=attention_backend,
    )
    causal_model = causal_model.to(training_args.device)
    causal_tokenizer = AutoTokenizer.from_pretrained(eval_model_name)

    generation_mode = data_args.conditional_generation
    is_conditional_generation = generation_mode is not None
    # Generation modes that are evaluated in prefix-LM fashion.
    prefix_lm_modes = (
        "prefix_lm",
        "ul2",
        "ul2_with_unconditional",
        "prefix_with_unconditional",
        "ul2_variable",
    )
    prefix_lm_eval = generation_mode in prefix_lm_modes

    def compute_metrics(results):
        # The remaining options are read from the argument objects on every call.
        return evaluate_generation(
            results,
            data_args,
            causal_model,
            causal_tokenizer,
            is_conditional_generation,
            prefix_lm_eval=prefix_lm_eval,
            skip_special_tokens=data_args.skip_special_tokens,
            eval_for_all_metrics=training_args.eval_for_all_metrics,
        )

    return compute_metrics


# so we evaluate on the first step, useful for checking training is working.
class EvaluateFirstStepCallback(TrainerCallback):
    """Callback that requests an evaluation once the first step has finished."""

    def on_step_end(self, args, state, control, **kwargs):
        # Nothing to do unless the global step counter is exactly 1.
        if state.global_step != 1:
            return
        control.should_evaluate = True


def _build_train_dataset(model_args, data_args, training_args, tokenizer, pad_token_id):
    """Load and tokenize the training split, then subsample, filter and shuffle it."""
    raw_datasets = load_data(data_args, model_args)
    train_dataset = tokenize_data_new(
        data_args, tokenizer, raw_datasets, training_args
    )["train"]
    if data_args.max_train_samples is not None:
        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
        train_dataset = train_dataset.select(range(max_train_samples))
    if data_args.min_train_seq_length != 0:
        train_dataset = train_dataset.filter(
            filter_by_length(data_args.min_train_seq_length, pad_token_id)
        )
    if data_args.shuffle:
        if data_args.streaming:
            # Streaming datasets are shuffled through a bounded buffer.
            train_dataset = train_dataset.shuffle(
                seed=training_args.seed, buffer_size=10_000
            )
        else:
            train_dataset = train_dataset.shuffle(seed=training_args.seed)
    return train_dataset


def _default_c4_subset_dir():
    """Return the directory of the C4 subset for the storage that is available."""
    if is_weka_available():
        return "/data/input/jaket/c4_subset"
    if is_nfs_available():
        return "/net/nfs.cirrascale/allennlp/jaket/simplex-diffusion/c4_subset"
    # yale
    return "/home/jt856/documents/simplex-diffusion/raw/c4_subset"


def _build_eval_dataset(data_args, training_args, tokenizer, pad_token_id):
    """Load and tokenize the C4 validation shard, then subsample and filter it."""
    # default to c4
    validation_file = os.path.join(
        _default_c4_subset_dir(), "c4-validation.00000-of-00008.json"
    )
    c4_raw_dataset = datasets.IterableDatasetDict(
        {
            # A single data file is exposed by `load_dataset` under the "train" key.
            "validation": datasets.load_dataset("json", data_files=validation_file)[
                "train"
            ]
        }
    )
    c4_tokenized_datasets = tokenize_data_new(
        data_args, tokenizer, c4_raw_dataset, training_args
    )
    eval_dataset = c4_tokenized_datasets["validation"]
    if data_args.max_eval_samples is not None:
        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
        eval_dataset = eval_dataset.select(range(max_eval_samples))
    if data_args.min_eval_seq_length != 0:
        eval_dataset = eval_dataset.filter(
            filter_by_length(data_args.min_eval_seq_length, pad_token_id),
            num_proc=data_args.preprocessing_num_workers,
        )
    return eval_dataset


def _preprocess_logits_for_metrics(logits):
    """Reduce logits to the argmax index over the last dimension."""
    return logits.argmax(dim=-1)


def main():
    """Parse the arguments, then run training and/or evaluation.

    Steps, in order: configure logging, look up the last checkpoint, seed,
    load the model and tokenizer, build the datasets that are needed
    (``do_train`` / ``do_eval``), create the data collator factory, the
    metrics function and the noise schedulers, build the ``DiffusionTrainer``
    and finally train and/or evaluate, logging and saving the metrics.
    """
    # parse args
    model_args, data_args, training_args, diffusion_args = get_args()
    set_pretraining_dataset(data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    # (the trailing ", " keeps the two halves of the message from running together)
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = get_last_checkpoint_with_beaker_preemption(training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # load model
    tokenizer, model = load_model(
        model_args, data_args, training_args, diffusion_args, logger
    )
    # The length filters below compare token ids against the pad token id.
    assert (
        model.config.pad_token_id is not None
    ), "model.config.pad_token_id must be set"
    pad_token_id = model.config.pad_token_id

    # Datasets stay None for the phases that are not requested.
    train_dataset = None
    if training_args.do_train:
        train_dataset = _build_train_dataset(
            model_args, data_args, training_args, tokenizer, pad_token_id
        )

    eval_dataset = None
    if training_args.do_eval:
        eval_dataset = _build_eval_dataset(
            data_args, training_args, tokenizer, pad_token_id
        )

    # Data collator
    # TODO: fix lambda max_seq_length, extra_padding_ratio:
    pad_to_multiple_of_8 = (
        data_args.line_by_line
        and training_args.fp16
        and not data_args.pad_to_max_length
    )

    def data_collator(mode):
        # Factory: the trainer receives this callable and builds one collator per mode.
        return SpanInfillingDataCollator(
            mode=mode,
            data_args=data_args,
            tokenizer=tokenizer,
            padding="max_length" if data_args.pad_to_max_length else True,
            max_length=data_args.max_seq_length,
            seed=training_args.seed,
            pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
            eval_context_size=data_args.eval_context_size,
        )

    compute_metrics = None
    if training_args.do_eval and not training_args.without_compute_metrics:
        # call only when necessary
        compute_metrics = get_compute_metrics(data_args, training_args, model_args)

    # init schedulers
    noise_scheduler = TokenWiseSimplexDDPMScheduler(
        num_train_timesteps=diffusion_args.num_diffusion_steps,
        beta_schedule=diffusion_args.beta_schedule,
        simplex_value=diffusion_args.simplex_value,
        clip_sample=diffusion_args.clip_sample,
        device=training_args.device,
        multiply_factor=diffusion_args.multiply_factor,
    )
    # One inference scheduler per requested number of inference diffusion steps.
    inference_noise_schedulers = [
        TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=timesteps,
            beta_schedule=diffusion_args.beta_schedule,
            simplex_value=diffusion_args.simplex_value,
            clip_sample=diffusion_args.clip_sample,
            device=training_args.device,
            multiply_factor=diffusion_args.multiply_factor,
        )
        for timesteps in diffusion_args.num_inference_diffusion_steps
    ]

    # Initialize our Trainer
    trainer = DiffusionTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=_preprocess_logits_for_metrics
        if training_args.do_eval
        else None,
        noise_scheduler=noise_scheduler,
        diffusion_args=diffusion_args,
        data_args=data_args,
        inference_noise_schedulers=inference_noise_schedulers,
    )
    trainer.add_callback(EvaluateFirstStepCallback())

    # Training
    if training_args.do_train:
        checkpoint = resolve_last_checkpoint_vs_resume_from_checkpoint(
            last_checkpoint,
            training_args.resume_from_checkpoint,
        )
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload
        metrics = train_result.metrics

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        if training_args.load_states_in_eval_from_model_path:
            # Restore weights, trainer state and RNG state from the model path.
            trainer._load_from_checkpoint(model_args.model_name_or_path)
            trainer.state = TrainerState.load_from_json(
                os.path.join(model_args.model_name_or_path, "trainer_state.json")
            )
            trainer._load_rng_state(model_args.model_name_or_path)

        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        max_eval_samples = (
            data_args.max_eval_samples
            if data_args.max_eval_samples is not None
            else len(eval_dataset)
        )
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


# Run the training/evaluation pipeline only when this file is executed as the
# main program, not when it is imported.
if __name__ == "__main__":
    main()