"""
General utilities.
"""
import json
import os
from typing import List, Union, Dict
from functools import cmp_to_key
import math
from collections.abc import Iterable

from datasets import load_dataset

# Parent directory of the directory containing this module (built as "<module dir>/../");
# used as the base of load_data_split's default dataset directory.
ROOT_DIR = os.path.join(os.path.dirname(__file__), "../")

def _load_table(table_path) -> dict:
    """
    attention: the table_path must be the .tsv path.
    Load the WikiTableQuestion from csv file. Result in a dict format like:
    {"header": [header1, header2,...], "rows": [[row11, row12, ...], [row21,...]... [...rownm]]}
    """

    def __extract_content(_line: str):
        _vals = [_.replace("\n", " ").strip() for _ in _line.strip("\n").split("\t")]
        return _vals

    with open(table_path, "r") as f:
        lines = f.readlines()

        rows = []
        for i, line in enumerate(lines):
            line = line.strip('\n')
            if i == 0:
                header = line.split("\t")
            else:
                rows.append(__extract_content(line))

    table_item = {"header": header, "rows": rows}

    # Defense assertion
    for i in range(len(rows) - 1):
        if not len(rows[i]) == len(rows[i - 1]):
            raise ValueError('some rows have diff cols.')

    return table_item


def majority_vote(
        nsqls: List,
        pred_answer_list: List,
        allow_none_and_empty_answer: bool = False,
        allow_error_answer: bool = False,
        answer_placeholder: Union[str, int] = '<error|empty>',
        vote_method: str = 'prob',
        answer_biased: Union[str, int] = None,
        answer_biased_weight: float = None,
):
    """
    Determine the final nsql execution answer by majority vote.

    Args:
        nsqls: sequence of ``(nsql, logprob)`` pairs, aligned with ``pred_answer_list``.
        pred_answer_list: execution result of each nsql. ``'<error>'`` marks an
            execution error; ``[None]`` and ``[]`` mark a none/empty answer.
        allow_none_and_empty_answer: if True, none/empty answers are voted for as
            ``[answer_placeholder]`` instead of being discarded.
        allow_error_answer: if True, ``'<error>'`` answers are voted for as
            ``[answer_placeholder]`` instead of being discarded.
        answer_placeholder: value substituted for none/empty/error answers.
        vote_method: one of
            ``'simple'``: highest vote count, ties broken by the best (max) logprob;
            ``'prob'``: highest sum of nsql probabilities (``exp(logprob)``);
            ``'answer_biased'``: like ``'simple'``, but the count of the answer
            ``(answer_biased,)`` is multiplied by ``answer_biased_weight``;
            ``'lf_biased'``: like ``'simple'``, but each nsql containing ``map@``
            or ``ans@`` counts 10 and any other nsql counts 1.
            Remaining ties go to the answer that appeared first.
        answer_biased: the answer to boost for ``'answer_biased'``.
        answer_biased_weight: positive multiplier for ``'answer_biased'``.

    Returns:
        ``(pred_answer, pred_answer_nsqls)``: the winning answer as a list and the
        ``[nsql, logprob]`` pairs that produced it. If no valid answer remains,
        returns ``answer_placeholder`` (not wrapped in a list) together with
        ``[(first_nsql, first_logprob)]``, or ``[]`` when ``nsqls`` is empty.

    Raises:
        ValueError: if ``vote_method`` is unsupported, or if ``answer_biased_weight``
            is missing or not positive for ``'answer_biased'``.
    """

    def _simple_key(item):
        """Vote count first; ties broken by the max logprob among the answer's nsqls."""
        answer_info = item[1]
        return answer_info['count'], max(logprob for _, logprob in answer_info['nsqls'])

    def _prob_key(item):
        """Sum of the probabilities of the nsqls that produced the answer."""
        return sum(math.exp(logprob) for _, logprob in item[1]['nsqls'])

    # Vote answers
    candi_answer_dict = dict()
    for (nsql, logprob), pred_answer in zip(nsqls, pred_answer_list):
        if allow_none_and_empty_answer and (pred_answer == [None] or pred_answer == []):
            pred_answer = [answer_placeholder]
        if allow_error_answer and pred_answer == '<error>':
            pred_answer = [answer_placeholder]

        # Invalid execution results
        if pred_answer == '<error>' or pred_answer == [None] or pred_answer == []:
            continue
        answer_info = candi_answer_dict.setdefault(tuple(pred_answer), {'count': 0, 'nsqls': []})
        answer_info['count'] += 1
        answer_info['nsqls'].append([nsql, logprob])

    # All candidates execution errors (or nothing to vote on at all).
    if not candi_answer_dict:
        if not nsqls:
            return answer_placeholder, []
        return answer_placeholder, [(nsqls[0][0], nsqls[0][-1])]

    if vote_method == 'simple':
        key = _simple_key
    elif vote_method == 'prob':
        key = _prob_key
    elif vote_method == 'answer_biased':
        # Specifically for Tabfact entailed answer, i.e., `1`.
        # If there exists nsql that produces `1`, we consider it more significant because `0` is very common.
        if answer_biased_weight is None or not answer_biased_weight > 0:
            raise ValueError(
                f"answer_biased_weight must be a positive number, got {answer_biased_weight!r}.")
        for answer, answer_info in candi_answer_dict.items():
            if answer == (answer_biased,):
                answer_info['count'] *= answer_biased_weight
        key = _simple_key
    elif vote_method == 'lf_biased':
        # Assign weights to different types of logic forms (lf) to control interpretability and coverage
        for answer_info in candi_answer_dict.values():
            answer_info['count'] = sum(
                10 if ('map@' in nsql or 'ans@' in nsql) else 1
                for nsql, _ in answer_info['nsqls']
            )
        key = _simple_key
    else:
        raise ValueError(f"Vote method {vote_method} is not supported.")

    # max() keeps the first maximal candidate, so exact ties go to the answer seen first.
    best_answer, best_info = max(candi_answer_dict.items(), key=key)
    return list(best_answer), best_info['nsqls']


def load_data_split(dataset_to_load, split, data_dir=os.path.join(ROOT_DIR, 'datasets/')):
    """
    Load one split of a dataset and unify the names of its keys.

    The dataset is loaded with ``datasets.load_dataset`` from the loading script
    ``<data_dir>/<dataset_to_load>.py``, with its cache under ``<data_dir>/data``.

    Args:
        dataset_to_load: dataset name; must be one of the supported names below.
        split: name of the split to select from the loaded dataset.
        data_dir: directory holding the dataset loading scripts.

    Returns:
        For the WikiTQ-family datasets, the loaded split unchanged. Otherwise a
        list of the split's items with unified keys added:
          * ``tab_fact``: ``question`` (from ``statement``), ``answer_text``
            (from ``label``) and ``table['page_title']`` (from ``table['caption']``);
          * ``hybridqa``: ``table['page_title']`` = the part of ``context``
            before the first ``' | '``;
          * ``mmqa``: ``table['page_title']`` (from ``table['title']``).

    Raises:
        ValueError: if ``dataset_to_load`` is not supported (checked after loading).
    """
    dataset_split_loaded = load_dataset(
        path=os.path.join(data_dir, "{}.py".format(dataset_to_load)),
        cache_dir=os.path.join(data_dir, "data"))[split]

    # Datasets whose items already use the unified key names.
    already_unified = {
        'wikitq', 'has_squall', 'missing_squall',
        'wikitq_sql_solvable', 'wikitq_sql_unsolvable',
        'wikitq_sql_unsolvable_but_in_squall',
        'wikitq_scalability_ori',
        'wikitq_scalability_100rows',
        'wikitq_scalability_200rows',
        'wikitq_scalability_500rows',
        'wikitq_robustness',
    }
    if dataset_to_load in already_unified:
        return dataset_split_loaded

    def _unify_tab_fact(data_item):
        data_item['question'] = data_item['statement']
        data_item['answer_text'] = data_item['label']
        data_item['table']['page_title'] = data_item['table']['caption']
        return data_item

    def _unify_hybridqa(data_item):
        data_item['table']['page_title'] = data_item['context'].split(' | ')[0]
        return data_item

    def _unify_mmqa(data_item):
        data_item['table']['page_title'] = data_item['table']['title']
        return data_item

    unifiers = {
        'tab_fact': _unify_tab_fact,
        'hybridqa': _unify_hybridqa,
        'mmqa': _unify_mmqa,
    }
    unify = unifiers.get(dataset_to_load)
    if unify is None:
        raise ValueError(f'{dataset_to_load} dataset is not supported now.')
    return [unify(data_item) for data_item in dataset_split_loaded]


def pprint_dict(dic):
    """Pretty-print ``dic`` to stdout as JSON indented by two spaces."""
    rendered = json.dumps(dic, indent=2)
    print(rendered)


def flatten(nested_list):
    """
    Lazily yield the leaves of an arbitrarily nested iterable, depth-first.

    ``str`` and ``bytes`` are treated as leaves rather than iterated element by
    element. An explicit stack of iterators is used instead of recursion, so very
    deeply nested input does not hit Python's recursion limit.
    """
    stack = [iter(nested_list)]
    while stack:
        try:
            x = next(stack[-1])
        except StopIteration:
            # Current level exhausted: resume the enclosing one.
            stack.pop()
            continue
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            stack.append(iter(x))
        else:
            yield x