import copy
import os
import sqlite3
import records
import sqlalchemy
import pandas as pd
from typing import Dict, List
import uuid

from utils.normalizer import convert_df_type, prepare_df_for_neuraldb_from_table
from utils.mmqa.image_stuff import get_caption


def check_in_and_return(key: str, source: dict):
    """Look up ``key`` in ``source``, falling back to a case-insensitive match.

    A key wrapped in backticks (e.g. `` `some title` ``) is treated as one
    whole name: the surrounding backticks are stripped before the lookup.

    @param key: the name to look up, optionally wrapped in backticks.
    @param source: the mapping to search; its keys are compared as strings.
    @return: the value stored under the exact key if present, otherwise the
        value of the first key (in iteration order) that equals ``key``
        ignoring case.
    @raise ValueError: when neither an exact nor a case-insensitive match exists.
    """
    is_wrapped = key.startswith("`") and key.endswith("`")
    name = key[1:-1] if is_wrapped else key

    # Exact match wins over any case-insensitive candidate.
    if name in source:
        return source[name]

    wanted = name.lower()
    for candidate, value in source.items():
        if candidate.lower() == wanted:
            return value

    raise ValueError("{} not in {}".format(name, source))


class NeuralDB(object):
    """A single-table SQLite database with optional passage/image grounding.

    The table is written to a freshly created SQLite file under ``tmp/`` and is
    queried through a ``records`` connection. Passages and images are kept in
    dictionaries keyed by title; when the raw table carries ``rows_with_links``
    the (lower-cased, stripped) linked cell text is mapped to the title of the
    passage/image it points to.
    """

    def __init__(self, tables: List[Dict[str, Dict]], passages=None, images=None):
        """Build the database from ``tables`` and register extra modalities.

        @param tables: list of table infos; exactly one entry is supported.
            Each entry holds the raw table under ``'table'`` and optionally a
            ``'title'``. NOTE: the entries are modified in place, their
            ``'table'`` value is replaced by the result of
            ``prepare_df_for_neuraldb_from_table``.
        @param passages: optional iterable of dicts with ``'title'`` and ``'text'``.
        @param images: optional iterable of dicts with ``'id'``, ``'title'`` and ``'pic'``.
        @raise AssertionError: when ``tables`` is empty.
        @raise ValueError: when more than one table is given.
        """
        # Validate before doing any work, so that a rejected input neither gets
        # its tables mutated nor leaves a stray SQLite file/connection behind.
        assert len(tables) >= 1, "DB has no table inside"
        if len(tables) > 1:
            raise ValueError("More than one table not support yet.")

        self.raw_tables = copy.deepcopy(tables)
        self.passages = {}
        self.images = {}
        self.image_captions = {}
        self.passage_linker = {}  # The links from cell value to passage
        self.image_linker = {}  # The links from cell value to images

        # Get passages
        if passages:
            for passage in passages:
                title, passage_content = passage['title'], passage['text']
                self.passages[title] = passage_content

        # Get images
        if images:
            for image in images:
                _id, title, picture = image['id'], image['title'], image['pic']
                self.images[title] = picture
                self.image_captions[title] = get_caption(_id)

        # Link grounding resources from other modalities(passages, images).
        self._link_grounding_resources(passages, images)

        # Replace each raw table by its prepared form (done on the caller's
        # objects; the untouched originals live in self.raw_tables).
        for table_info in tables:
            table_info['table'] = prepare_df_for_neuraldb_from_table(table_info['table'])

        self.tables = tables

        # Connect to SQLite database; a random file name avoids collisions
        # between instances sharing the same tmp directory.
        self.tmp_path = "tmp"
        os.makedirs(self.tmp_path, exist_ok=True)
        self.db_path = os.path.join(self.tmp_path, '{}.db'.format(uuid.uuid4()))
        self.sqlite_conn = sqlite3.connect(self.db_path)

        # Create DB: the single supported table is always stored as "w".
        table_0 = tables[0]
        table_0["table"].to_sql("w", self.sqlite_conn)
        self.table_name = "w"
        self.table_title = table_0.get('title', None)

        # Records conn
        self.db = records.Database('sqlite:///{}'.format(self.db_path))
        self.records_conn = self.db.get_connection()

    def _link_grounding_resources(self, passages, images):
        """Fill ``passage_linker`` / ``image_linker`` from the raw table links.

        Does nothing when the first raw table has no (or empty)
        ``rows_with_links``. Each linked cell is read as three parallel
        sequences (texts, titles, urls); the title is mapped to the
        lower-cased, stripped text. ``passages`` and ``images`` may be ``None``.
        """
        raw_table = self.raw_tables[0]['table']
        rows_with_links = raw_table.get('rows_with_links', None)
        if not rows_with_links:
            return

        rows = raw_table['rows']
        link_title2cell_map = {}
        for row_id, row in enumerate(rows):
            for col_id in range(len(row)):
                cell = rows_with_links[row_id][col_id]
                for text, title, _url in zip(cell[0], cell[1], cell[2]):
                    link_title2cell_map[title] = text.lower().strip()

        # Link Passages (`or ()` guards against the default passages=None)
        for passage in passages or ():
            title = passage['title']
            linked_cell = link_title2cell_map.get(title, None)
            if linked_cell:
                self.passage_linker[linked_cell] = title

        # Link Images (`or ()` guards against the default images=None)
        for image in images or ():
            title = image['title']
            linked_cell = link_title2cell_map.get(title, None)
            if linked_cell:
                self.image_linker[linked_cell] = title

    def __str__(self):
        """Return the string form of the full table query result."""
        return str(self.execute_query("SELECT * FROM {}".format(self.table_name)))

    def get_table(self, table_name=None):
        """Return ``{"header": ..., "rows": ...}`` for the whole table.

        @param table_name: table to read; defaults to the instance's table.
        """
        table_name = self.table_name if not table_name else table_name
        sql_query = "SELECT * FROM {}".format(table_name)
        _table = self.execute_query(sql_query)
        return _table

    def get_header(self, table_name=None):
        """Return the column headers of the table (default: the instance's table)."""
        _table = self.get_table(table_name)
        return _table['header']

    def get_rows(self, table_name=None):
        """Return the rows of the table (default: the instance's table)."""
        _table = self.get_table(table_name)
        return _table['rows']

    def get_table_df(self):
        """Return the prepared form of the first table, as stored in ``self.tables``."""
        return self.tables[0]['table']

    def get_table_raw(self):
        """Return the raw (deep-copied, unprepared) first table."""
        return self.raw_tables[0]['table']

    def get_table_title(self):
        """Return the first table's ``'title'``; raises KeyError if it has none."""
        return self.tables[0]['title']

    def get_passages_titles(self):
        """Return the titles of all registered passages."""
        return list(self.passages.keys())

    def get_images_titles(self):
        """Return the titles of all registered images."""
        return list(self.images.keys())

    def get_passage_by_title(self, title: str):
        """Return the passage text for ``title`` (see ``check_in_and_return``)."""
        return check_in_and_return(title, self.passages)

    def get_image_by_title(self, title):
        """Return the picture for ``title`` (see ``check_in_and_return``)."""
        return check_in_and_return(title, self.images)

    def get_image_caption_by_title(self, title):
        """Return the caption for ``title`` (see ``check_in_and_return``)."""
        return check_in_and_return(title, self.image_captions)

    def get_image_linker(self):
        """Return a deep copy of the cell-text -> image-title mapping."""
        return copy.deepcopy(self.image_linker)

    def get_passage_linker(self):
        """Return a deep copy of the cell-text -> passage-title mapping."""
        return copy.deepcopy(self.passage_linker)

    def execute_query(self, sql_query: str):
        """
        Basic operation. Execute the sql query on the database we hold.

        Three cases are distinguished:
        - a single token or a backtick-wrapped name is treated as a column
          name and expanded to ``SELECT row_id, <col> FROM <table>``;
        - a query starting (case-insensitively) with ``select *`` or
          ``select col_id`` is executed unchanged;
        - any other query gets ``row_id`` prepended to its select list, and is
          re-run unchanged if that rewritten form raises an OperationalError.

        @param sql_query: the query, or a bare column name.
        @return: dict with ``"header"`` (result column names) and ``"rows"``
            (one list of values per result row).
        """
        lowered_query = sql_query.lower()
        # When the sql query is a column name (@deprecated: or a certain value with '' and "" surrounded).
        if len(sql_query.split(' ')) == 1 or (sql_query.startswith('`') and sql_query.endswith('`')):
            col_name = sql_query
            new_sql_query = r"SELECT row_id, {} FROM {}".format(col_name, self.table_name)
            # Here we use a hack that when a value is surrounded by '' or "", the sql will return a column of the value,
            # while for variable, no ''/"" surrounded, this sql will query for the column.
            out = self.records_conn.query(new_sql_query)
        # When the sql query wants all cols or col_id, which is no need for us to add 'row_id'.
        # Both prefixes are matched case-insensitively.
        elif lowered_query.startswith(("select *", "select col_id")):
            out = self.records_conn.query(sql_query)
        else:
            try:
                # SELECT row_id in addition, needed for result and old table alignment.
                # The slice drops the leading 7 characters, i.e. "SELECT ".
                new_sql_query = "SELECT row_id, " + sql_query[7:]
                out = self.records_conn.query(new_sql_query)
            except sqlalchemy.exc.OperationalError:
                # Execute normal SQL, and in this case the row_id is actually in no need.
                out = self.records_conn.query(sql_query)

        results = out.all()
        headers = out.dataset.headers
        rows = [list(row.values()) for row in results]

        return {"header": headers, "rows": rows}

    def add_sub_table(self, sub_table, table_name=None, verbose=True):
        """
        Add sub_table into the table.

        The current table is read back, left-joined with ``sub_table`` on
        ``row_id`` and written again, replacing the stored table.

        @param sub_table: dict with ``'header'`` and ``'rows'``; the header is
            expected to contain ``row_id`` since it is the join key.
        @param table_name: table to extend; defaults to the instance's table.
        @param verbose: when true, print the inserted columns and their dtypes.
        @return:
        """
        table_name = self.table_name if not table_name else table_name
        sql_query = "SELECT * FROM {}".format(table_name)
        origin_table = self.execute_query(sql_query)
        old_table = pd.DataFrame(origin_table["rows"], columns=origin_table["header"])
        # concat the new column into old table
        sub_table_df_normed = convert_df_type(pd.DataFrame(data=sub_table['rows'], columns=sub_table['header']))
        new_table = old_table.merge(sub_table_df_normed,
                                    how='left', on='row_id')  # do left join
        new_table.to_sql(table_name, self.sqlite_conn, if_exists='replace',
                         index=False)
        if verbose:
            print("Insert column(s) {} (dtypes: {}) into table.\n".format(', '.join(sub_table['header']),
                                                                          sub_table_df_normed.dtypes))