import re
from datasets import load_dataset, Dataset, DatasetDict
from itertools import chain
from tqdm import tqdm
from collections import Counter
from accelerate import Accelerator

LANGUAGES_TO_DECODE_FROM_BYTES = ["he", "fr", "uk"]
STREAMING_DATASETS = ["fineweb-edu"]


def load_pg19_val_and_test():
    # Load the dataset in streaming mode
    streaming_dataset = load_dataset("deepmind/pg19", split=None, streaming=True)

    # Extract test and validation splits
    test_split = list(streaming_dataset["test"])
    validation_split = list(streaming_dataset["validation"])

    # Convert them into regular datasets
    test_dataset = Dataset.from_list(test_split)
    validation_dataset = Dataset.from_list(validation_split)

    # validation_dataset = load_dataset("deepmind/pg19", split="validation")
    # test_dataset = load_dataset("deepmind/pg19", split="test")

    return DatasetDict({"validation": validation_dataset, "test": test_dataset})
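
# Example usage (sketch; assumes the deepmind/pg19 dataset is reachable on the Hub):
#   pg19 = load_pg19_val_and_test()
#   print(pg19["validation"][0]["text"][:200])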


def load_pubmed(n_samples=10000):
    # Load the dataset in streaming mode
    streaming_dataset = load_dataset("MedRAG/pubmed", streaming=True)

    # Take the first n_samples*4 examples and split them into train/validation/test
    data = list(streaming_dataset["train"].take(n_samples*4))
    train = data[:2*n_samples]
    validation = data[2*n_samples:3*n_samples]
    test = data[3*n_samples:]
    # Convert them into regular datasets
    train = Dataset.from_list(train)
    validation = Dataset.from_list(validation)
    test = Dataset.from_list(test)
    dataset = DatasetDict({"train": train, 'validation': validation, 'test': test})
    dataset = dataset.rename_column('content', 'text')
    return dataset
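
# Example usage (sketch; streams a slice of MedRAG/pubmed, so the first call may be slow):
#   pubmed = load_pubmed(n_samples=1000)
#   print(pubmed["validation"][0]["text"][:200])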


def load_lm_dataset(dataset_name, language="en", split=None):
    """
    Loads a popular pretraining or perplexity-evaluation dataset by name and language.

    Args:
        dataset_name (str): The name of the dataset to load. Options include:
            - 'wikitext' (WikiText-2, the smaller WikiText dataset)
            - 'wikitext-103' (the larger WikiText dataset)
            - 'fineweb-edu' (FineWeb-Edu, sample-10BT subset)
            - 'cord19' (CORD-19 full-text corpus)
            - 'pubmed' (PubMed chunks from MedRAG/pubmed, 'content' renamed to 'text')
            - 'wikilingua' (English WikiLingua with 'text'/'summary' columns)
            - 'xsum' (XSum summarization dataset, 'document' renamed to 'text')
            - 'cnn' (CNN/DailyMail 3.0.0 with 'text'/'summary' columns)
            - 'pg19' (Project Gutenberg dataset for long-context modeling; validation/test only)
            - 'wiki40b' (Wikipedia dataset in multiple languages)
        language (str): Language code for datasets with multilingual options (currently only
            'wiki40b'). Defaults to 'en'.
        split (str): Optional split to load (e.g., 'train', 'validation', 'test'); only
            forwarded to datasets loaded with an explicit split. Defaults to None.

    Returns:
        Dataset or DatasetDict: Loaded Hugging Face dataset.
    """
    if dataset_name.lower() == 'wikitext':
        return load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split)
    elif dataset_name.lower() == 'fineweb-edu':
        return load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT")
    elif dataset_name.lower() == 'wikitext-103':
        return load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split=split)
    elif dataset_name.lower() == 'cord19':
        return load_dataset("allenai/cord19", "fulltext", trust_remote_code=True)
    elif dataset_name.lower() == 'pubmed':
        return load_pubmed()
    elif dataset_name.lower() == 'wikilingua':
        dataset = load_dataset("GEM/wiki_lingua", trust_remote_code=True)
        dataset = dataset.filter(lambda ex: (ex['source_language'] == "en") & (ex['target_language'] == "en"))
        dataset = dataset.rename_column("source", "text")
        dataset = dataset.rename_column("target", "summary")
        return dataset
    elif dataset_name.lower() == 'xsum':
        dataset = load_dataset("EdinburghNLP/xsum")
        dataset = dataset.rename_column("document", "text")
        return dataset
    elif dataset_name.lower() == 'cnn':
        dataset = load_dataset("abisee/cnn_dailymail", "3.0.0")
        dataset = dataset.rename_column("article", "text")
        dataset = dataset.rename_column("highlights", "summary")
        dataset = dataset.map(lambda example: {"text": example["text"].replace("(CNN)", "")})
        return dataset
    elif dataset_name.lower() == 'pg19':
        return load_pg19_val_and_test()
    elif dataset_name.lower() == 'wiki40b':
        dataset = load_dataset("google/wiki40b", language, split=split)
        if language in LANGUAGES_TO_DECODE_FROM_BYTES:
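            # wiki40b stores text for some languages as escaped byte strings; round-trip
            # through latin1 to recover proper UTF-8 and restore newlines.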
            dataset = dataset.map(lambda x: {
                "text": bytes(x["text"][2:-1], "utf-8").decode("unicode_escape").encode("latin1").decode("utf-8").replace("_NEWLINE_", "\n")
            })
        return dataset
    else:
        raise ValueError(
            "Dataset not recognized. Available options: 'wikitext', 'wikitext-103', 'fineweb-edu', "
            "'cord19', 'pubmed', 'wikilingua', 'xsum', 'cnn', 'pg19', 'wiki40b'.")


def extract_new_words_from_dataset(
        dataset: Dataset, tokenizer, text_column: str = "text", max_samples: int = None, filter_func=(lambda word, token_count: True)):
    """
    Scans the specified text column of a dataset and extracts "new words": words that the
    tokenizer splits into more than one token, both with and without a leading whitespace.

    Args:
        dataset (Dataset): Hugging Face dataset to scan.
        tokenizer: Tokenizer used to count how many tokens each word maps to.
        text_column (str): The column in the dataset containing text. Defaults to 'text'.
        max_samples (int): If set, only the first `max_samples` examples are scanned.
        filter_func (callable): Optional predicate `(word, token_count) -> bool` used to
            further filter candidate words. Defaults to accepting every word.

    Returns:
        tuple: (list of new words, dict mapping each new word to its corpus frequency).
    """
    if max_samples:
        dataset = dataset.select(range(max_samples))

    # Regular expression to split text into words (adjust as needed for specific languages)
    # word_pattern = re.compile(r"\b\w+\b")
    word_pattern = re.compile(r"\b\w+(?:[-']\w+)*\b")
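    # e.g., "state-of-the-art" and "don't" are each matched as a single word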

    # Iterate over each entry in the dataset and collect all words (duplicates included)
    all_words = list()
    for record in tqdm(dataset, total=len(dataset), miniters=10, desc="Extracting all words from dataset...", unit="examples"):
        text = record.get(text_column, "")
        words = word_pattern.findall(text)
        all_words += words

    # all_words = list(dict.fromkeys(all_words))
    word_frequencies = Counter(all_words)
    all_words = list(word_frequencies.keys())
    token_counts = [len(x) for x in tokenizer(all_words, add_special_tokens=False)["input_ids"]]
    w_whitespace_token_counts = [
        len(x) for x in tokenizer([f" {w}" for w in all_words], add_special_tokens=False)["input_ids"]
    ]

    # Keep words that the tokenizer splits into multiple tokens both with and without a leading space
    new_words = [
        word for word, count, w_whitespace_count in zip(all_words, token_counts, w_whitespace_token_counts)
        if (count > 1) and (w_whitespace_count > 1) and filter_func(word, count)
    ]
    new_words_freq = {word: word_frequencies[word] for word in new_words}

    return new_words, new_words_freq
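
# Example usage (sketch; `tokenizer` is any Hugging Face tokenizer, e.g. AutoTokenizer.from_pretrained("gpt2")):
#   words, freqs = extract_new_words_from_dataset(
#       dataset["train"], tokenizer, max_samples=1000,
#       filter_func=lambda word, token_count: word[0].islower(),
#   )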


def get_group_texts_func(block_size=1024):
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}

        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Drop the small remainder; if total_length < block_size, the whole batch is dropped.
        # Padding could be added instead of dropping; customize this part to your needs.
        total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result
    return group_texts
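
# Toy sketch of the chunking behaviour:
#   group = get_group_texts_func(block_size=4)
#   group({"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]]})
#   -> {"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8]], "labels": [[1, 2, 3, 4], [5, 6, 7, 8]]}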


def get_tokenize_func(tokenizer, text_col_name):
    def _tokenize(examples):
        output = tokenizer(
            examples[text_col_name],
            return_token_type_ids=False,
            add_special_tokens=False,
        )
        return output
    return _tokenize


def tokenize_and_prepare_dataset(
        dataset, tokenizer, accelerator=None,
        text_col_name: str = "text",
        max_length: int = 256,
        eval_max_samples: int = None,
):

    if tokenizer.bos_token is not None and max_length:
        # leave room for <BOS> token to be added:
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    tokenize_function = get_tokenize_func(tokenizer, text_col_name)

    column_names = dataset.column_names

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=column_names,
        load_from_cache_file=False,
        desc="Running tokenizer on dataset",
    )
    group_texts = get_group_texts_func(block_size=max_tokenized_len)
    lm_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
    )

    if eval_max_samples:
        lm_dataset = lm_dataset.select(range(eval_max_samples))

    return lm_dataset
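

# Minimal end-to-end sketch (assumption: `transformers` is installed and the GPT-2
# tokenizer / WikiText-2 validation split are acceptable stand-ins for your own setup).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    example_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    wikitext_val = load_lm_dataset("wikitext", split="validation")
    lm_val = tokenize_and_prepare_dataset(
        wikitext_val, example_tokenizer, max_length=256, eval_max_samples=8
    )
    print(lm_val)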