import argparse
import json
import textwrap
from itertools import combinations
from os import makedirs, mkdir
from os.path import join as pjoin, isdir

from data_measurements import dataset_statistics
from data_measurements import dataset_utils



def load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=False):
    """
    Load or compute every statistic backing the widgets used in the app.

    Args:
        ds_args: Dict of keyword arguments expanded into
            ``dataset_statistics.DatasetStatisticsCacheClass``. It must
            contain a "cache_dir" entry; that directory is created when it
            does not exist yet.
        show_embeddings: When True, also prepare the embeddings widget.
        use_cache: Forwarded to ``DatasetStatisticsCacheClass`` and to
            ``do_npmi``.

    Returns:
        None.
    """
    cache_dir = ds_args["cache_dir"]
    if not isdir(cache_dir):
        print("Creating cache")
        # We need to preprocess everything.
        # This should eventually all go into a prepare_dataset CLI
        # makedirs (unlike mkdir) also creates missing parent directories and
        # does not fail if the directory appears between the check and the call.
        makedirs(cache_dir, exist_ok=True)

    dstats = dataset_statistics.DatasetStatisticsCacheClass(
        **ds_args, use_cache=use_cache
    )
    # Header widget
    dstats.load_or_prepare_dset_peek()
    # General stats widget
    dstats.load_or_prepare_general_stats()
    # Labels widget
    try:
        dstats.set_label_field("label")
        dstats.load_or_prepare_labels()
    except Exception as err:
        # The field name "label" is only a guess, so the labels widget stays
        # best-effort: report the failure and carry on with the other widgets.
        print("Skipping labels widget: %s" % err)
    # Text lengths widget
    dstats.load_or_prepare_text_lengths()
    if show_embeddings:
        # Embeddings widget
        dstats.load_or_prepare_embeddings()
    # Text duplicates widget
    dstats.load_or_prepare_text_duplicates()
    # nPMI widget
    dstats.load_or_prepare_npmi()
    npmi_stats = dstats.npmi_stats
    # Handling for all pairs; in the UI, people select.
    do_npmi(npmi_stats, use_cache=use_cache)
    # Zipf widget
    dstats.load_or_prepare_zipf()


def load_or_prepare(dataset_args, do_html=False, use_cache=False):
    """
    Load or compute the statistics selected by ``dataset_args["calculation"]``.

    The dataset and its vocabulary are always prepared first. Each requested
    statistic is then loaded or computed, and its figure/JSON artifacts are
    written under ``dstats.cache_path``.

    Args:
        dataset_args: Dict of keyword arguments expanded into
            ``dataset_statistics.DatasetStatisticsCacheClass``. Its
            "calculation" entry selects what to run: "general", "lengths",
            "labels", "npmi", "zipf" or "embeddings". A missing or falsy
            value runs everything except embeddings.
        do_html: When True, also write the text-lengths figure as HTML. The
            label and Zipf figures are written as HTML regardless of this flag.
        use_cache: Forwarded to the statistics classes and to ``do_npmi``.

    Returns:
        None.
    """
    dstats = dataset_statistics.DatasetStatisticsCacheClass(
        **dataset_args, use_cache=use_cache
    )
    print("Loading dataset.")
    dstats.load_or_prepare_dataset()
    print("Dataset loaded.  Preparing vocab.")
    dstats.load_or_prepare_vocab()
    print("Vocab prepared.")

    calculation = dataset_args.get("calculation")
    # No explicit selection means "run every calculation" (except embeddings).
    run_all = not calculation

    if run_all or calculation == "general":
        print("\n* Calculating general statistics.")
        dstats.load_or_prepare_general_stats()
        print("Done!")
        print("Basic text statistics now available at %s." % dstats.general_stats_json_fid)
        print(
            "Text duplicates now available at %s." % dstats.dup_counts_df_fid
        )

    if run_all or calculation == "lengths":
        print("\n* Calculating text lengths.")
        fig_tok_length_fid = pjoin(dstats.cache_path, "lengths_fig.html")
        tok_length_json_fid = pjoin(dstats.cache_path, "lengths.json")
        dstats.load_or_prepare_text_lengths()
        with open(tok_length_json_fid, "w+") as f:
            json.dump(dstats.fig_tok_length.to_json(), f)
            print("Token lengths now available at %s." % tok_length_json_fid)
        if do_html:
            dstats.fig_tok_length.write_html(fig_tok_length_fid)
            print("Figure saved to %s." % fig_tok_length_fid)
        print("Done!")

    if run_all or calculation == "labels":
        if not dstats.label_field:
            print("Warning: You asked for label calculation, but didn't provide "
                  "the labels field name.  Assuming it is 'label'...")
            dstats.set_label_field("label")
        # The calculation itself must run whether the label field was supplied
        # by the caller or defaulted just above.
        print("\n* Calculating label distribution.")
        dstats.load_or_prepare_labels()
        fig_label_html = pjoin(dstats.cache_path, "labels_fig.html")
        fig_label_json = pjoin(dstats.cache_path, "labels.json")
        dstats.fig_labels.write_html(fig_label_html)
        with open(fig_label_json, "w+") as f:
            json.dump(dstats.fig_labels.to_json(), f)
        print("Done!")
        print("Label distribution now available at %s." % dstats.label_dset_fid)
        print("Figure saved to %s." % fig_label_html)

    if run_all or calculation == "npmi":
        print("\n* Preparing nPMI.")
        npmi_stats = dataset_statistics.nPMIStatisticsCacheClass(
            dstats, use_cache=use_cache
        )
        do_npmi(npmi_stats, use_cache=use_cache)
        print("Done!")
        print(
            "nPMI results now available in %s for all identity terms that "
            "occur more than 10 times and all words that "
            "co-occur with both terms."
            % npmi_stats.pmi_cache_path
        )

    if run_all or calculation == "zipf":
        print("\n* Preparing Zipf.")
        zipf_fig_fid = pjoin(dstats.cache_path, "zipf_fig.html")
        zipf_json_fid = pjoin(dstats.cache_path, "zipf_fig.json")
        dstats.load_or_prepare_zipf()
        zipf_fig = dstats.zipf_fig
        with open(zipf_json_fid, "w+") as f:
            json.dump(zipf_fig.to_json(), f)
        zipf_fig.write_html(zipf_fig_fid)
        print("Done!")
        print("Zipf results now available at %s." % dstats.zipf_fid)
        print(
            "Figure saved to %s, with corresponding json at %s."
            % (zipf_fig_fid, zipf_json_fid)
        )

    # Don't do this one until someone specifically asks for it -- takes awhile.
    if calculation == "embeddings":
        print("\n* Preparing text embeddings.")
        dstats.load_or_prepare_embeddings()


def do_npmi(npmi_stats, use_cache=True):
    """
    Load or compute joint nPMI statistics for every pair of available terms.

    Each unordered pair of distinct terms is processed exactly once and is
    handed over as a tuple whose two terms are in sorted order.

    Args:
        npmi_stats: Object exposing ``load_or_prepare_npmi_terms()``, which
            returns the terms to pair up, and
            ``load_or_prepare_joint_npmi(pair)``.
        use_cache: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        None.
    """
    available_terms = npmi_stats.load_or_prepare_npmi_terms()
    print("Iterating through terms for joint npmi.")
    # De-duplicating and sorting first means combinations() yields each
    # unordered pair once, already in sorted order. Rebinding the loop
    # variables while iterating (as a nested loop is tempted to do) can skip
    # pairs, so the pairs are generated up front instead.
    unique_terms = sorted(set(available_terms))
    for first_term, second_term in combinations(unique_terms, 2):
        print(
            "Computing nPMI statistics for %s and %s" % (first_term, second_term)
        )
        npmi_stats.load_or_prepare_joint_npmi((first_term, second_term))


def get_text_label_df(
    ds_name,
    config_name,
    split_name,
    text_field,
    label_field,
    calculation,
    out_dir,
    do_html=False,
    use_cache=True,
):
    """
    Assemble the dataset arguments and prepare the app widgets for a dataset.

    Args:
        ds_name: Dataset name; used to look up the dataset info dicts.
        config_name: Dataset configuration name.
        split_name: Dataset split name.
        text_field: Name of the text column.
        label_field: Name of the label column. When falsy, an empty label
            field and an empty list of label names are used.
        calculation: Stored under "calculation" in the dataset arguments.
        out_dir: Stored under "cache_dir" in the dataset arguments.
        do_html: Accepted but not forwarded anywhere by this function.
        use_cache: Forwarded to ``load_or_prepare_widgets``.

    Returns:
        None.
    """
    if not use_cache:
        print("Not using any cache; starting afresh")
    ds_name_to_dict = dataset_utils.get_dataset_info_dicts(ds_name)

    label_names = []
    if not label_field:
        label_field = ()
    else:
        label_feature = ds_name_to_dict[ds_name][config_name]["features"][label_field]
        if len(label_feature) > 0:
            # The first entry unpacks into (label field, label names).
            label_field, label_names = label_feature[0]
        else:
            label_field = ()

    dataset_args = dict(
        dset_name=ds_name,
        dset_config=config_name,
        split_name=split_name,
        text_field=text_field,
        label_field=label_field,
        label_names=label_names,
        calculation=calculation,
        cache_dir=out_dir,
    )
    load_or_prepare_widgets(dataset_args, use_cache=use_cache)


def main():
    """
    Parse the command-line arguments and prepare the requested dataset.

    The parsed arguments are echoed to stdout and then handed to
    ``get_text_label_df``.
    """
    # TODO: Make this the Hugging Face arg parser
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(
            """

         Example for hate speech18 dataset:
         python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --feature="text"

         Example for IMDB dataset:
         python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
         """
        ),
    )

    parser.add_argument(
        "-d", "--dataset", required=True, help="Name of dataset to prepare"
    )
    parser.add_argument(
        "-c", "--config", required=True, help="Dataset configuration to prepare"
    )
    parser.add_argument(
        "-s", "--split", required=True, type=str, help="Dataset split to prepare"
    )
    # Not marked as required: argparse never applies a default to a required
    # option, so `required=True` would make the "text" default unreachable.
    parser.add_argument(
        "-f",
        "--feature",
        type=str,
        default="text",
        help="Text column to prepare (defaults to 'text')",
    )
    parser.add_argument(
        "-w",
        "--calculation",
        help="""What to calculate (defaults to everything except embeddings).\n
                                                    Options are:\n

                                                    - `general` (for duplicate counts, missing values, length statistics.)\n

                                                    - `lengths` for text length distribution\n

                                                    - `labels` for label distribution\n

                                                    - `embeddings` (Warning: Slow.)\n

                                                    - `npmi` for word associations\n

                                                    - `zipf` for zipfian statistics
                                                    """,
    )
    parser.add_argument(
        "-l",
        "--label_field",
        type=str,
        required=False,
        default="",
        help="Field name for label column in dataset (Required if there is a label field that you want information about)",
    )
    parser.add_argument(
        "--cached",
        default=False,
        required=False,
        action="store_true",
        help="Whether to use cached files (Optional)",
    )
    parser.add_argument(
        "--do_html",
        default=False,
        required=False,
        action="store_true",
        help="Whether to write out corresponding HTML files (Optional)",
    )
    parser.add_argument("--out_dir", default="cache_dir", help="Where to write out to.")

    args = parser.parse_args()
    print("Proceeding with the following arguments:")
    print(args)
    # Example: run_data_measurements.py -d hate_speech18 -c default -s train -f text -w npmi
    get_text_label_df(
        args.dataset,
        args.config,
        args.split,
        args.feature,
        args.label_field,
        args.calculation,
        args.out_dir,
        do_html=args.do_html,
        use_cache=args.cached,
    )
    print()


# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()