import argparse
import json
from utils.sql.process_sql import (
  tokenize, CLAUSE_KEYWORDS, WHERE_OPS, COND_OPS, UNIT_OPS, AGG_OPS,
  JOIN_KEYWORDS, ORDER_OPS, skip_semicolon, SQL_OPS)
# WHERE operators that the template extractors emit verbatim; every other
# operator in WHERE_OPS is masked as "[WHERE_OP]".
KEPT_WHERE_OP = ('not', 'in', 'exists')


def parse_table_unit(toks, start_idx, tables_with_alias):
  """Consume one ``table [as alias]`` unit of a FROM clause.

  When the table name is followed by ``as``, the next token is registered in
  ``tables_with_alias`` as an alias of that table.

  :param toks: token list.
  :param start_idx: index of the table-name token.
  :param tables_with_alias: alias -> table mapping; mutated in place.
  :returns: (index of the first token after the unit, table name)
  """
  table_name = toks[start_idx]
  has_alias = start_idx + 1 < len(toks) and toks[start_idx + 1] == "as"
  if not has_alias:
    return start_idx + 1, table_name
  # ``table as alias`` occupies three tokens.
  tables_with_alias[toks[start_idx + 2]] = table_name
  return start_idx + 3, table_name

def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
  """Record the column token at ``toks[start_idx]`` in ``schema``.

  The token is rewritten in place to the qualified form ``table.col`` and
  ``col`` is appended to ``schema[table]`` (the entry is created if needed).
  A ``*`` token is skipped without being recorded.

  * A qualified token ``prefix.col`` is resolved through
    ``tables_with_alias``; when ``prefix`` is not a registered alias it is
    taken to be the table name itself (``FROM singer ... singer.age``).
  * An unqualified token goes to ``default_tables[0]`` when exactly one table
    is in scope. Otherwise it goes to the table whose name is textually most
    similar to the column token (difflib ratio). This is a name-similarity
    heuristic, not a schema lookup. Ties go to the earliest table.

  :param toks: token list; mutated in place.
  :param start_idx: index of the column token.
  :param tables_with_alias: mapping from alias to table name.
  :param schema: mapping from table name to a list of column names;
      mutated in place.
  :param default_tables: tables in scope for unqualified columns.
  :returns: the index of the token after the column.
  :raises AssertionError: for an unqualified column when ``default_tables``
      is None or empty.
  """
  import difflib

  tok = toks[start_idx]
  if tok == "*":
    return start_idx + 1

  if '.' in tok:  # qualified column: alias.col or table.col
    alias, col = tok.split('.')
    # Only ``table AS alias`` units are registered as aliases, so fall back
    # to the prefix itself when it is a plain table name.
    table = tables_with_alias.get(alias, alias)
  else:
    assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
    col = tok
    if len(default_tables) == 1:
      table = default_tables[0]
    else:
      # Several tables in scope and no qualifier: choose the table whose
      # name best matches the column token. max() keeps the first maximum.
      lower_tok = tok.lower()
      table = max(
          default_tables,
          key=lambda name: difflib.SequenceMatcher(None, lower_tok, name.lower()).ratio())

  if table not in schema:
    schema[table] = []
  schema[table].append(col)
  # Normalize the token so downstream consumers see ``table.col``.
  toks[start_idx] = "{}.{}".format(table, col)
  return start_idx + 1

def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None, end_idx=None):
  """Parse ``[(] [agg ( [distinct] col )] | [distinct] col [)]``.

  The column itself is handled by parse_col, which records it in ``schema``
  and rewrites the token in place. In the aggregate form, an opening ``(``
  before the aggregate is not matched with a closing ``)`` here.

  :param end_idx: optional exclusive upper bound (an absolute index into
      ``toks``) on the tokens the aggregate form may occupy.
  :returns: the index of the token after the unit.
  """
  idx = start_idx
  # ``idx`` is absolute, so the bound must be absolute as well; a slice
  # length would be relative to ``start_idx`` and reject valid aggregates.
  len_ = len(toks) if end_idx is None else min(end_idx, len(toks))
  isBlock = False
  if toks[idx] == '(':
    isBlock = True
    idx += 1

  if toks[idx] in AGG_OPS:
    idx += 1
    assert idx < len_ and toks[idx] == '('
    idx += 1
    if toks[idx] == "distinct":
      idx += 1
    idx = parse_col(toks, idx, tables_with_alias, schema, default_tables)
    assert idx < len_ and toks[idx] == ')'
    idx += 1
    return idx

  if toks[idx] == "distinct":
    idx += 1
  idx = parse_col(toks, idx, tables_with_alias, schema, default_tables)

  if isBlock:
    assert toks[idx] == ')'
    idx += 1  # skip ')'

  return idx

def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
  """Parse ``[(] col_unit [unit_op col_unit] [)]``, e.g. ``t1.a - t1.b``.

  Each column unit is parsed (and recorded in ``schema``) by parse_col_unit.

  :returns: the index of the token after the value unit.
  """
  total = len(toks)
  wrapped = toks[start_idx] == '('
  pos = start_idx + 1 if wrapped else start_idx

  pos = parse_col_unit(toks, pos, tables_with_alias, schema, default_tables)
  # A unit operator joins a second column unit.
  if pos < total and toks[pos] in UNIT_OPS:
    pos = parse_col_unit(toks, pos + 1, tables_with_alias, schema, default_tables)

  if wrapped:
    assert toks[pos] == ')'
    pos += 1  # consume the closing ')'
  return pos

def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
  """Parse the right-hand side of a condition, masking literals in place.

  * ``select ...`` is parsed recursively as a sub-query.
  * A token containing a double quote is replaced with ``"_str_value_"``.
  * A token accepted by ``float()`` is replaced with ``"_num_value_"``.
  * Anything else is treated as a column expression. It extends up to the
    next ``,``, ``and``, ``or``, an unmatched ``)``, a clause keyword or a
    join keyword, and its column is recorded via parse_col_unit.

  :returns: the index of the token after the value.
  """
  idx = start_idx
  len_ = len(toks)

  isBlock = False
  if toks[idx] == '(':
    isBlock = True
    idx += 1

  if toks[idx] == 'select':
    idx = parse_sql(toks, idx, schema)
  elif "\"" in toks[idx]:  # token is a string value
    toks[idx] = "_str_value_"
    idx += 1
  else:
    try:
      float(toks[idx])
    except ValueError:
      # Column expression: find where it ends. Parentheses opened inside the
      # value (e.g. ``avg ( t1.x )``) are tracked, so their ``)`` does not
      # end the scan. ``or`` ends it just like ``and`` so that columns in
      # later OR'ed conditions are still parsed.
      end_idx = idx
      depth = 0
      while end_idx < len_:
        tok = toks[end_idx]
        if tok == '(':
          depth += 1
        elif tok == ')':
          if depth == 0:
            break
          depth -= 1
        elif tok in (',', 'and', 'or') or tok in CLAUSE_KEYWORDS or tok in JOIN_KEYWORDS:
          break
        end_idx += 1

      parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables, end_idx=end_idx)
      idx = end_idx
    else:
      toks[idx] = "_num_value_"
      idx += 1

  if isBlock:
    assert toks[idx] == ')'
    idx += 1

  return idx

def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
  """Parse a chain of conditions joined by ``and``/``or`` (from COND_OPS).

  Each condition has the shape ``val_unit [not] where_op value``. The
  ``between`` operator takes two values separated by ``and``. Columns are
  recorded in ``schema`` and literals are masked via parse_val_unit and
  parse_value.

  :returns: the index of the token that ended the chain (a clause keyword,
      ``)``, ``;`` or a join keyword), or ``len(toks)``.
  :raises AssertionError: when a condition has no recognised operator.
  """
  idx = start_idx
  len_ = len(toks)

  while idx < len_:
    idx = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
    # Optional negation (``not in``, ``not like``). Bounds-checked so that a
    # truncated condition trips the assert below instead of an IndexError.
    if idx < len_ and toks[idx] == 'not':
      idx += 1

    assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(
        idx, toks[idx] if idx < len_ else None)
    op_id = WHERE_OPS.index(toks[idx])
    idx += 1
    if op_id == WHERE_OPS.index('between'):  # between..and...: first value + 'and'
      idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)
      assert idx < len_ and toks[idx] == 'and'
      idx += 1
    idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)

    if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
      break

    if idx < len_ and toks[idx] in COND_OPS:
      idx += 1  # skip and/or
  return idx


def parse_from(toks, start_idx, schema):
  """Parse the FROM clause of the query that starts at ``start_idx``.

  Each table unit is registered in ``schema`` (with an empty column list if
  new). ``table AS alias`` units populate the alias map. Sub-queries are
  parsed recursively, and ``ON`` conditions are parsed so their columns are
  recorded too.

  :returns: (index of the token after the FROM clause, table names in order
      of appearance, mapping alias -> table name)
  :raises AssertionError: when no ``from`` token follows ``start_idx``.
  """
  assert 'from' in toks[start_idx:], "'from' not found"
  # Tokens that only separate table units: ``join`` with its modifiers, and
  # the comma of an implicit cross join (``FROM a, b``). If they were not
  # skipped, they would be parsed as table names and recorded in ``schema``.
  separators = ('join', ',', 'left', 'right', 'inner', 'outer', 'cross', 'natural', 'full')
  tables_with_alias = {}
  default_tables = []

  len_ = len(toks)
  idx = toks.index('from', start_idx) + 1
  while idx < len_:
    while idx < len_ and toks[idx] in separators:
      idx += 1
    if idx >= len_:
      break

    isBlock = False
    if toks[idx] == '(':
      isBlock = True
      idx += 1

    if toks[idx] == 'select':
      idx = parse_sql(toks, idx, schema)
    else:
      idx, table_name = parse_table_unit(toks, idx, tables_with_alias)
      default_tables.append(table_name)
      if table_name not in schema:
        schema[table_name] = []

    if idx < len_ and toks[idx] == "on":
      idx += 1  # skip on
      idx = parse_condition(toks, idx, tables_with_alias, schema, default_tables)

    if isBlock:
      assert toks[idx] == ')'
      idx += 1

    if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
      break

  return idx, default_tables, tables_with_alias

def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
  """Walk the SELECT list and record every referenced column in ``schema``.

  :returns: the index of the first clause keyword after the list, or
      ``len(toks)``.
  :raises AssertionError: when ``toks[start_idx]`` is not ``select``.
  """
  assert toks[start_idx] == 'select', "'select' not found"
  end = len(toks)
  pos = start_idx + 1
  if pos < end and toks[pos] == 'distinct':
    pos += 1

  while pos < end and toks[pos] not in CLAUSE_KEYWORDS:
    # A leading aggregate keyword is consumed here, so its "( ... )" is then
    # parsed as a parenthesised value unit.
    if toks[pos] in AGG_OPS:
      pos += 1
    pos = parse_val_unit(toks, pos, tables_with_alias, schema, default_tables)
    if pos < end and toks[pos] == ',':
      pos += 1  # move on to the next select item
  return pos

def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
  """Parse an optional WHERE clause beginning at ``start_idx``.

  :returns: ``start_idx`` unchanged when no WHERE clause starts there,
      otherwise the index after its conditions.
  """
  has_where = start_idx < len(toks) and toks[start_idx] == 'where'
  if not has_where:
    return start_idx
  return parse_condition(toks, start_idx + 1, tables_with_alias, schema, default_tables)

def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
  """Parse an optional ``GROUP BY col_unit {, col_unit}`` clause.

  :returns: ``start_idx`` when no GROUP BY starts there, otherwise the index
      of the first token after the grouping columns.
  """
  total = len(toks)
  if not (start_idx < total and toks[start_idx] == 'group'):
    return start_idx

  pos = start_idx + 1
  assert toks[pos] == 'by'
  pos += 1

  while pos < total and toks[pos] not in CLAUSE_KEYWORDS and toks[pos] not in (")", ";"):
    pos = parse_col_unit(toks, pos, tables_with_alias, schema, default_tables)
    if not (pos < total and toks[pos] == ','):
      break
    pos += 1  # another grouping column follows
  return pos

def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
  """Parse an optional HAVING clause beginning at ``start_idx``.

  :returns: ``start_idx`` unchanged when no HAVING clause starts there,
      otherwise the index after its conditions.
  """
  if start_idx < len(toks) and toks[start_idx] == 'having':
    return parse_condition(toks, start_idx + 1, tables_with_alias, schema, default_tables)
  return start_idx

def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
  """Parse an optional ``ORDER BY val_unit [asc|desc] {, ...}`` clause.

  :returns: ``start_idx`` when no ORDER BY starts there, otherwise the index
      of the first token after the ordering items.
  """
  total = len(toks)
  if start_idx >= total or toks[start_idx] != 'order':
    return start_idx

  assert toks[start_idx + 1] == 'by'
  pos = start_idx + 2

  while pos < total and toks[pos] not in CLAUSE_KEYWORDS and toks[pos] not in (")", ";"):
    pos = parse_val_unit(toks, pos, tables_with_alias, schema, default_tables)
    # Optional direction keyword from ORDER_OPS.
    if pos < total and toks[pos] in ORDER_OPS:
      pos += 1
    if pos < total and toks[pos] == ',':
      pos += 1
      continue
    break
  return pos

def parse_limit(toks, start_idx):
  """Mask the value of an optional ``LIMIT n`` clause in place.

  The token after ``limit`` is always replaced with ``"_limit_value_"``,
  whatever it is, so the placeholder works even for non-numeric limits.

  :returns: ``start_idx`` when no LIMIT starts there; otherwise the index
      after the limit value, or after ``limit`` itself when no value follows.
  """
  idx = start_idx
  len_ = len(toks)

  if idx >= len_ or toks[idx] != 'limit':
    return idx
  if idx + 1 >= len_:
    # Dangling ``limit`` with nothing after it: nothing to mask.
    return idx + 1
  toks[idx + 1] = "_limit_value_"
  return idx + 2

def parse_sql(toks, start_idx, schema):
  """Walk one (possibly parenthesised) query and record its tables/columns.

  FROM is parsed first, so aliases and the in-scope tables are known when
  the SELECT list is walked. The remaining clauses are then parsed in order.
  If the query is followed by a set-operation keyword from SQL_OPS, the next
  query is parsed recursively.

  :param toks: token list; mutated in place by the clause parsers.
  :param start_idx: index of ``select`` or of an opening ``(``.
  :param schema: table -> column list mapping; mutated in place.
  :returns: the index of the token after the query (and any chained query).
  """
  total = len(toks)
  pos = start_idx
  in_parens = toks[pos] == '('
  if in_parens:
    pos += 1

  from_end, default_tables, tables_with_alias = parse_from(toks, start_idx, schema)
  parse_select(toks, pos, tables_with_alias, schema, default_tables)

  # The clauses after FROM, in SQL order; each returns the next index.
  pos = from_end
  for clause_parser in (parse_where, parse_group_by, parse_having, parse_order_by):
    pos = clause_parser(toks, pos, tables_with_alias, schema, default_tables)
  pos = parse_limit(toks, pos)

  pos = skip_semicolon(toks, pos)
  if in_parens:
    assert toks[pos] == ')'
    pos += 1  # consume the closing ')'
  pos = skip_semicolon(toks, pos)

  if pos < total and toks[pos] in SQL_OPS:
    pos = parse_sql(toks, pos + 1, schema)
  return pos

def extract_schema_from_sql(schema, sql):
  """Tokenize ``sql`` and record every table/column it references.

  ``schema`` (table -> list of columns) is mutated in place.

  :returns: the token list after parsing. Columns are rewritten to
      ``table.col`` and literal values are replaced with placeholders.
  """
  tokens = tokenize(sql)
  parse_sql(tokens, 0, schema)
  return tokens

def extract_template_from_sql(sql, schema=None):
  """Abstract ``sql`` into a coarse list of template tokens.

  Clause keywords, aggregates and the listed punctuation are kept.

  * The whole FROM part becomes ``[FROM_PART]``, unless ``from`` is followed
    by ``(``.
  * ``asc``/``desc`` become ``[ORDER_DIRECTION]``.
  * WHERE operators become ``[WHERE_OP]``, except those in KEPT_WHERE_OP, and
    the token right after ``[WHERE_OP]`` becomes ``[VALUE]``.
  * The token after ``limit`` becomes ``[LIMIT_VALUE]``.
  * Any other run of tokens collapses to a single ``[MASK]``.

  :param sql: SQL string. If tokenization fails, a message is printed and an
      empty template is returned.
  :param schema: unused; kept for backward compatibility.
  :returns: list of template tokens.
  """
  try:
    toks = tokenize(sql)
  except Exception:
    print("Tokenization error for {}".format(sql))
    toks = []
  template = []
  len_ = len(toks)
  idx = 0
  while idx < len_:
    tok = toks[idx]
    if tok == "from":
      template.append(tok)
      # A trailing ``from``, or one that opens a sub-query, falls through.
      if idx + 1 < len_ and toks[idx + 1] != "(":
        template.append("[FROM_PART]")
        idx += 1
        # Skip tables, aliases and join conditions up to the next clause.
        while idx < len_ and (toks[idx] not in CLAUSE_KEYWORDS and toks[idx] != ")"):
          idx += 1
        continue
    elif tok in CLAUSE_KEYWORDS:
      template.append(tok)
    elif tok in AGG_OPS:
      template.append(tok)
    elif tok in [",", "*", "(", ")", "having", "by", "distinct"]:
      template.append(tok)
    elif tok in ["asc", "desc"]:
      template.append("[ORDER_DIRECTION]")
    elif tok in WHERE_OPS:
      if tok in KEPT_WHERE_OP:
        template.append(tok)
      else:
        template.append("[WHERE_OP]")
        if tok == "between":
          # Skip the lower bound and ``and``; the upper bound becomes [VALUE].
          idx += 2
    elif tok in COND_OPS:
      template.append(tok)
    elif template and template[-1] == "[WHERE_OP]":
      template.append("[VALUE]")
    elif template and template[-1] == "limit":
      template.append("[LIMIT_VALUE]")
    elif not template or template[-1] != "[MASK]":  # value, schema, join on as
      template.append("[MASK]")
    idx += 1
  return template

def extract_partial_template_from_sql(sql, schema=None):
  """Like extract_template_from_sql, but keeps schema tokens verbatim.

  * The FROM part is copied token by token instead of becoming
    ``[FROM_PART]``.
  * Tokens that would become ``[MASK]`` are kept as-is.
  * Order directions, WHERE operators, WHERE values and the LIMIT value are
    still abstracted.

  Unlike extract_template_from_sql, tokenization errors are not caught.

  :param sql: SQL string.
  :param schema: unused; kept for backward compatibility.
  :returns: list of template tokens.
  """
  toks = tokenize(sql)
  template = []
  len_ = len(toks)
  idx = 0
  while idx < len_:
    tok = toks[idx]
    if tok == "from":
      template.append(tok)
      # A trailing ``from``, or one that opens a sub-query, falls through.
      if idx + 1 < len_ and toks[idx + 1] != "(":
        idx += 1
        # Copy tables, aliases and join conditions up to the next clause.
        while idx < len_ and (toks[idx] not in CLAUSE_KEYWORDS and toks[idx] != ")"):
          template.append(toks[idx])
          idx += 1
        continue
    elif tok in CLAUSE_KEYWORDS:
      template.append(tok)
    elif tok in AGG_OPS:
      template.append(tok)
    elif tok in [",", "*", "(", ")", "having", "by", "distinct"]:
      template.append(tok)
    elif tok in ["asc", "desc"]:
      template.append("[ORDER_DIRECTION]")
    elif tok in WHERE_OPS:
      if tok in KEPT_WHERE_OP:
        template.append(tok)
      else:
        template.append("[WHERE_OP]")
        if tok == "between":
          # Skip the lower bound and ``and``; the upper bound becomes [VALUE].
          idx += 2
    elif tok in COND_OPS:
      template.append(tok)
    elif template and template[-1] == "[WHERE_OP]":
      template.append("[VALUE]")
    elif template and template[-1] == "limit":
      template.append("[LIMIT_VALUE]")
    else:
      template.append(tok)
    idx += 1
  return template


def is_valid_schema(schema):
  """Return False if any extracted table or column name looks malformed.

  A table name must not contain "." and must not be a clause keyword. A
  column name must not contain ".", a space, or a single/double quote.
  """
  bad_column_chars = (".", " ", '"', "'")
  for table, columns in schema.items():
    if "." in table or table in CLAUSE_KEYWORDS:
      return False
    if any(ch in column for column in columns for ch in bad_column_chars):
      return False
  return True

def clean_sql(sql):
  """Repair dangling JOIN keywords in generated SQL.

  Runs of ``JOIN JOIN`` collapse to a single ``JOIN``. Then a ``JOIN``
  directly before ``WHERE`` or ``GROUP BY`` is dropped.
  """
  cleaned = sql
  while "JOIN JOIN" in cleaned:
    cleaned = cleaned.replace("JOIN JOIN", "JOIN")
  for dangling, fixed in (("JOIN WHERE", "WHERE"), ("JOIN GROUP BY", "GROUP BY")):
    cleaned = cleaned.replace(dangling, fixed)
  return cleaned

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--input_file", type=str)
  parser.add_argument("--output_file", type=str)
  parser.add_argument("--mode", type=str, choices=["debug", "verbose", "silent"])
  parser.add_argument("--task", type=str, choices=["template_extraction", "schema_extraction"])
  args = parser.parse_args()
  # NOTE(review): mode "silent" is accepted but no branch below handles it,
  # so it currently produces no output for either task.

  if args.task == "schema_extraction":
    if args.mode == "debug":
      # Sample queries; only the last assignment is actually parsed.
      sql = "SELECT count(*) FROM games"
      sql = sql + " INTERSECT " + "SELECT sacks, year FROM players"
      sql = sql + " EXCEPT " + 'SELECT T1.year, T1.sacks FROM players AS T1 JOIN tackles AS T2 ON T1.id = T2.player_id WHERE T2.manager = "A" and T2.season NOT IN (SELECT season FROM match WHERE match_name = "IVL" INTERSECT SELECT T1.year, T1.sacks FROM sack AS T1) GROUP BY T1.year, T1.sacks HAVING count(T1.coach) > 10 ORDER BY T2.score LIMIT 5'
      sql = "SELECT T1.pld FROM pld AS T1 JOIN games AS T2 ON T1.crs_code = T2.crs_code JOIN GROUP BY T1.pld WHERE T2.gf = '8' AND T2.gf = '9'"
      sql = 'select * from head where height = "6-0" or height = "6-0" order by height asc'
      schema = {}
      extract_schema_from_sql(schema, sql)
      print(schema, is_valid_schema(schema))
    elif args.mode == "verbose":
      # Writes one JSON line per example whose SQL parses and whose extracted
      # schema passes is_valid_schema; all other examples are skipped.
      with open(args.input_file, encoding="utf-8") as fin, \
          open(args.output_file, "w", encoding="utf-8") as fout:
        for line in fin:
          example = json.loads(line)
          schema = {}
          try:
            sql = example["sql"] if "sql" in example else example["pred"]
            sql = clean_sql(sql)
            example["sql"] = sql
            extract_schema_from_sql(schema, sql)
          except Exception:
            # Best effort: unparsable SQL is skipped rather than aborting.
            continue
          # De-duplicate columns but keep first-seen order, so the output is
          # reproducible (set iteration order depends on hash seeding).
          for table in schema:
            schema[table] = list(dict.fromkeys(schema[table]))
          if is_valid_schema(schema):
            example["extracted_schema"] = schema
            fout.write(json.dumps(example) + "\n")
  elif args.task == "template_extraction":
    if args.mode == "debug":
      sql = "SELECT avg(T1.Votes) FROM seats AS T1 JOIN votes AS T2 ON T1.Seat_ID = T2.Seat_ID WHERE T1.seats BETWEEN 1 AND 2 and T1.Seats = 1 AND T2.Votes = 10"
      print(extract_template_from_sql(sql))
      print(extract_partial_template_from_sql(sql))
    elif args.mode == "verbose":
      # Group the input SQL strings by their template string.
      templates = {}
      with open(args.input_file, encoding="utf-8") as fin:
        for line in fin:
          example = json.loads(line)
          sql = example["sql"] if "sql" in example else example["pred"]
          if isinstance(sql, list):
            sql = sql[-1]
          template_str = " ".join(extract_template_from_sql(sql))
          templates.setdefault(template_str, []).append(sql)
      print("{} has template {}".format(args.input_file, len(templates)))

      with open(args.output_file + ".json", "w", encoding="utf-8") as fout_json, \
          open(args.output_file + ".txt", "w", encoding="utf-8") as fout_txt, \
          open(args.output_file + ".low_freq", "w", encoding="utf-8") as low_freq_txt, \
          open(args.output_file + ".high_freq", "w", encoding="utf-8") as high_freq_txt:
        json.dump(templates, fout_json)
        for template in sorted(templates):
          # Templates seen more than once go to .high_freq, singletons to .low_freq.
          if len(templates[template]) > 1:
            high_freq_txt.write(template + "\n")
          else:
            low_freq_txt.write(template + "\n")
          fout_txt.write(template + "\n")