import re
import json
import records
from typing import List, Dict
from sqlalchemy.exc import SQLAlchemyError
from utils.sql.all_keywords import ALL_KEY_WORDS


class WTQDBEngine:
    """Small helper around a SQLite database opened through ``records``.

    The connection is kept on ``self.conn`` and reused by every method.
    """

    def __init__(self, fdb):
        """Open the SQLite file at ``fdb`` and grab a connection to it."""
        self.db = records.Database('sqlite:///{}'.format(fdb))
        self.conn = self.db.get_connection()

    def execute_wtq_query(self, sql_query: str):
        """Run ``sql_query`` and return every cell of every row in one flat list.

        Rows are visited in result order and each row contributes its
        values in the order ``values()`` yields them.
        """
        fetched_rows = self.conn.query(sql_query).all()
        flat_values = []
        for record in fetched_rows:
            flat_values.extend(record.values())
        return flat_values

    def delete_rows(self, row_indices: List[int]):
        """Issue one ``delete`` statement against table ``w`` per given id."""
        for row_id in row_indices:
            self.conn.query("delete from w where id == {}".format(row_id))


def process_table_structure(_wtq_table_content: Dict, _add_all_column: bool = False):
    """Convert a column-oriented WTQ table dict into a row-oriented one.

    The first two entries of ``headers``, ``types`` and ``contents`` are
    skipped (the id and agg columns). Header names and cell values are
    lower-cased and have newlines replaced by spaces.

    :param _wtq_table_content: dict with keys ``headers``, ``types``,
        ``contents`` and ``is_list``. Each entry of ``contents`` is a list of
        column variants, each a dict with ``col``, ``data`` and ``type``.
    :param _add_all_column: when False, only the first variant of each column
        is kept and the original headers/types are returned. When True, every
        variant except those whose alias contains ``_number`` is kept, and the
        header/type lists are rebuilt from the variants.
    :return: dict with ``header``, ``rows``, ``types`` and ``alias`` (the keys
        of ``is_list``).
    """

    def _clean(cell):
        # normalise a single cell to a flat, lower-cased string
        return str(cell).replace("\n", " ").lower()

    # remove id and agg column
    base_headers = [name.replace("\n", " ").lower() for name in _wtq_table_content["headers"][2:]]
    # aliases are 1-based: c1 -> first kept header, c2 -> second, ...
    alias_to_header = {"c{}".format(pos): name for pos, name in enumerate(base_headers, start=1)}
    base_types = _wtq_table_content["types"][2:]

    expanded_headers = []
    expanded_types = []
    columns = []
    for variants in _wtq_table_content["contents"][2:]:
        if not _add_all_column:
            # only take the first one
            columns.append([_clean(cell) for cell in variants[0]["data"]])
            continue

        for variant in variants:
            alias = variant["col"]
            # do not add the numbered column
            if "_number" in alias:
                continue
            columns.append([_clean(cell) for cell in variant["data"]])

            # "c3_first_part" -> "<header of c3> first part"
            prefix, underscore, suffix = alias.partition("_")
            if underscore:
                display_name = alias_to_header[prefix] + " " + suffix.replace("_", " ")
            else:
                display_name = alias_to_header[alias]
            expanded_headers.append(display_name)
            expanded_types.append("text" if variant["type"] == "TEXT" else "number")

    # transpose column-major data into rows
    rows = [list(row) for row in zip(*columns)]

    return {
        "header": expanded_headers if _add_all_column else base_headers,
        "rows": rows,
        "types": expanded_types if _add_all_column else base_types,
        "alias": list(_wtq_table_content["is_list"].keys())
    }


def retrieve_wtq_query_answer(_engine, _table_content, _sql_struct: List):
    """Flatten a tokenised SQL query, execute it, and return its answers.

    :param _engine: object exposing ``execute_wtq_query(sql)`` which returns
        an iterable of answer values.
    :param _table_content: dict whose ``header`` list maps column aliases
        (``c1`` -> index 0, ``c2`` -> index 1, ...) to readable names.
    :param _sql_struct: list of tokens; element ``[1]`` of each token is the
        token text, e.g. ``["Keyword", "select", []]``.
    :return: tuple ``(encoded_sql, answers, executable_sql)`` where
        ``encoded_sql`` has column aliases replaced by back-quoted header
        names, ``answers`` is the list of stringified non-None results (empty
        if the query raised ``SQLAlchemyError`` or any answer is the string
        ``"none"``), and ``executable_sql`` is the string sent to the engine.
    """
    # do not append id / agg
    headers = _table_content["header"]

    def flatten_sql(_ex_sql_struct: List):
        """Return ``(executable_sql, encoded_sql)`` built from the tokens."""
        # [ "Keyword", "select", [] ], [ "Column", "c4", [] ]
        _encode_sql = []
        _execute_sql = []
        for _ex_tuple in _ex_sql_struct:
            keyword = str(_ex_tuple[1])
            # upper the keywords.
            if keyword in ALL_KEY_WORDS:
                keyword = keyword.upper()

            if re.fullmatch(r"c\d+(_.+)?", keyword):
                # only the leading "c<N>" part selects the header (1-based)
                index_key = int(keyword.split("_")[0][1:]) - 1
                # wrap it with `` to make it executable
                _encode_sql.append("`{}`".format(headers[index_key]))
            else:
                # keywords, the table name `w`, literals, operators: keep as-is
                _encode_sql.append(keyword)

            # c4_list / c4_address: execute against the base column c4.
            # Tokens that merely contain "_list"/"_address" but carry no
            # c<N> part (e.g. a string literal) are left untouched instead
            # of raising IndexError.
            if "_address" in keyword or "_list" in keyword:
                base_column = re.search(r"c\d+", keyword)
                if base_column is not None:
                    keyword = base_column.group(0)

            _execute_sql.append(keyword)

        return " ".join(_execute_sql), " ".join(_encode_sql)

    _exec_sql_str, _encode_sql_str = flatten_sql(_sql_struct)
    try:
        _sql_answers = _engine.execute_wtq_query(_exec_sql_str)
    except SQLAlchemyError:
        # a query the database rejects simply yields no answers
        _sql_answers = []
    _norm_sql_answers = [str(_).replace("\n", " ") for _ in _sql_answers if _ is not None]
    # a literal "none" answer invalidates the whole answer list
    if "none" in _norm_sql_answers:
        _norm_sql_answers = []
    return _encode_sql_str, _norm_sql_answers, _exec_sql_str


def _load_table_w_page(table_path, page_title_path=None) -> dict:
    """
    Load a WikiTableQuestions table and attach its page title.

    attention: the table_path must be the .tsv path.

    The table itself is loaded by ``utils.utils._load_table``; presumably it
    yields a dict like
    {"header": [header1, header2,...], "rows": [[row11, row12, ...], [row21,...]... [...rownm]]}
    (TODO confirm against that helper). This function adds a ``page_title``
    key to that dict.

    :param table_path: path to the table's .tsv file.
    :param page_title_path: path to a JSON file with a ``title`` field. When
        omitted it is derived from ``table_path`` by replacing every "csv"
        with "page" and ".tsv" with ".json".
    :return: the loaded table dict with ``page_title`` set.
    """

    # imported inside the function, as in the original code
    from utils.utils import _load_table

    table_item = _load_table(table_path)

    # Load page title
    if not page_title_path:
        page_title_path = table_path.replace("csv", "page").replace(".tsv", ".json")
    # JSON text is UTF-8; do not depend on the platform's default encoding
    with open(page_title_path, "r", encoding="utf-8") as f:
        page_title = json.load(f)['title']
    table_item['page_title'] = page_title

    return table_item