import json
import os
import uuid
import pandas as pd
import streamlit as st
import argparse
import traceback
import requests
from utils.utils import load_data_split
from nsql.database import NeuralDB
from nsql.nsql_exec import NSQLExecutor
from nsql.nsql_exec_python import NPythonExecutor
from generation.generator import Generator
import time

st.set_page_config(
    page_title="Binder Demo",
    page_icon="🔗",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'About': "Check out our [website](https://lm-code-binder.github.io/) for more details!"
    }
)

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# todo: Add more Binder example questions; they need careful cherry-picking.
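# Each entry maps an example table title -> (index in the validation split, example question).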
EXAMPLE_TABLES = {
    "Estonia men's national volleyball team": (558, "what are the total number of players from france?"),
    # 'how old is kert toobal'
    "Highest mountain peaks of California": (5, "which is the lowest mountain?"),
    # 'which mountain is in the most north place?'
    "2010–11 UAB Blazers men's basketball team": (1, "how many players come from alabama?"),
    # 'how many players are born after 1996?'
    "Nissan SR20DET": (438, "which car has power more than 170 kw?"),
}


@st.cache
def load_data():
    return load_data_split("missing_squall", "validation")


@st.cache
def get_key():
    # print the public IP of the demo machine
    ip = requests.get('https://checkip.amazonaws.com').text.strip()
    print(ip)

    URL = "http://54.242.37.195:8080/api/predict"
    # A springboard server we set up to protect the API keys;
    # only the demo machine is allowed to fetch them.

    one_key = requests.post(url=URL, json={"data": "Hi, binder server. Give me a key!"}).json()['data'][0]
    return one_key
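

# When the springboard server is unreachable (e.g. running this demo locally),
# a simple alternative -- a sketch assuming you supply your own key through the
# OPENAI_API_KEY environment variable -- is:
def get_key_from_env():
    key = os.environ.get("OPENAI_API_KEY")
    assert key, "Set OPENAI_API_KEY to run the demo without the springboard server."
    return key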


def read_markdown(path):
    with open(path, "r") as f:
        output = f.read()
    st.markdown(output, unsafe_allow_html=True)


def generate_binder_program(_args, _generator, _data_item):
    n_shots = _args.n_shots
    few_shot_prompt = _generator.build_few_shot_prompt_from_file(
        file_path=_args.prompt_file,
        n_shots=n_shots
    )
    generate_prompt = _generator.build_generate_prompt(
        data_item=_data_item,
        generate_type=(_args.generate_type,)
    )
    prompt = few_shot_prompt + "\n\n" + generate_prompt

    # Ensure the prompt fits within the Codex input limit by shrinking n_shots.
    max_prompt_tokens = _args.max_api_total_tokens - _args.max_generation_tokens
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=os.path.join(ROOT_DIR, "utils", "gpt2"))
    while len(tokenizer.tokenize(prompt)) >= max_prompt_tokens:
        n_shots -= 1
        assert n_shots >= 0, "Prompt exceeds the token limit even with zero shots."
        few_shot_prompt = _generator.build_few_shot_prompt_from_file(
            file_path=_args.prompt_file,
            n_shots=n_shots
        )
        prompt = few_shot_prompt + "\n\n" + generate_prompt

    response_dict = _generator.generate_one_pass(
        prompts=[("0", prompt)],  # the "0" is the place taker, take effect only when there are multi threads
        verbose=_args.verbose
    )
    print(response_dict)
    return response_dict["0"][0][0]
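
# For reference, a generated NSQL program interleaves SQL with QA() calls that
# the LM answers at execution time. An illustrative example (not actual model
# output) for the volleyball table:
#   SELECT COUNT(*) FROM w WHERE QA("map@is this player from france?"; player) = 'yes'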


# Set up
import nltk

nltk.download('punkt')
parser = argparse.ArgumentParser()

parser.add_argument('--prompt_file', type=str, default='templates/prompts/prompt_wikitq_v3.txt')
# Binder program generation options
parser.add_argument('--prompt_style', type=str, default='create_table_select_3_full_table',
                    choices=['create_table_select_3_full_table',
                             'create_table_select_full_table',
                             'create_table_select_3',
                             'create_table',
                             'create_table_select_3_full_table_w_all_passage_image',
                             'create_table_select_3_full_table_w_gold_passage_image',
                             'no_table'])
parser.add_argument('--generate_type', type=str, default='nsql',
                    choices=['nsql', 'sql', 'answer', 'npython', 'python'])
parser.add_argument('--n_shots', type=int, default=14)
parser.add_argument('--seed', type=int, default=42)

# Codex options
# todo: Allow adjusting Codex parameters
parser.add_argument('--engine', type=str, default="code-davinci-002")
parser.add_argument('--max_generation_tokens', type=int, default=512)
parser.add_argument('--max_api_total_tokens', type=int, default=8001)
parser.add_argument('--temperature', type=float, default=0.)
parser.add_argument('--sampling_n', type=int, default=1)
parser.add_argument('--top_p', type=float, default=1.0)
parser.add_argument('--stop_tokens', type=str, default='\n\n',
                    help='Split stop tokens by ||')
parser.add_argument('--qa_retrieve_pool_file', type=str, default='templates/qa_retrieve_pool.json')

# debug options
parser.add_argument('-v', '--verbose', action='store_false',
                    help='Verbose output is on by default; pass -v to disable it.')
args = parser.parse_args()
keys = [get_key()]

# The title
st.markdown("# Binder Playground")

# Demo description
read_markdown('resources/demo_description.md')

# Upload tables/Switch tables

st.markdown('### Try Binder!')
col1, _ = st.columns(2)
with col1:
    selected_table_title = st.selectbox(
        "Select an example table",
        (
            "Estonia men's national volleyball team",
            "Highest mountain peaks of California",
            "2010–11 UAB Blazers men's basketball team",
            "Nissan SR20DET",
        )
    )

# Here we just use our own example tables for now.
data_items = load_data()
data_item = data_items[EXAMPLE_TABLES[selected_table_title][0]]
table = data_item['table']
header, rows, title = table['header'], table['rows'], table['page_title']
# todo: cache this DB instead of re-creating it on every rerun.
db = NeuralDB([{"title": title, "table": table}])
df = db.get_table_df()
st.markdown("Title: {}".format(title))
st.dataframe(df)

# Let user input the question
with col1:
    selected_language = st.selectbox(
        "Select a programming language",
        ("SQL", "Python"),
    )
if selected_language == 'SQL':
    args.prompt_file = 'templates/prompts/prompt_wikitq_v3.txt'
    args.generate_type = 'nsql'
elif selected_language == 'Python':
    args.prompt_file = 'templates/prompts/prompt_wikitq_python_simplified_v4.txt'
    args.generate_type = 'npython'
else:
    raise ValueError(f'{selected_language} language is not supported.')

question = st.text_input(
    "Ask a question about the table:(Please press enter at last!!)",
    placeholder=EXAMPLE_TABLES[selected_table_title][1],
)

if not question:
    st.stop()  # wait until the user submits a question

# Generate Binder Program
generator = Generator(args, keys=keys)
with st.spinner("Generating program ..."):
    binder_program = generate_binder_program(args, generator,
                                             {"question": question, "table": db.get_table_df(), "title": title})

# Do execution
st.subheader("Binder program")
if selected_language == 'SQL':
    st.markdown('```sql\n' + binder_program + '\n```')
    executor = NSQLExecutor(args, keys=keys)
elif selected_language == 'Python':
    st.code(binder_program, language='python')
    executor = NPythonExecutor(args, keys=keys)
    db = db.get_table_df()  # the Python executor operates on a plain DataFrame
else:
    raise ValueError(f'{selected_language} language is not supported.')
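# Execute the program. For SQL, the executor writes every intermediate step and
# its result to tmp_for_vis/ as JSON; we read those files back below to render
# the step-by-step visualization.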
try:
    stamp = str(uuid.uuid4())  # unique run id so concurrent sessions do not collide on tmp files
    os.makedirs('tmp_for_vis/', exist_ok=True)
    with st.spinner("Executing program ..."):
        exec_answer = executor.nsql_exec(stamp, binder_program, db)
    if selected_language == 'SQL':
        with open("tmp_for_vis/{}_tmp_for_vis_steps.txt".format(stamp), "r") as f:
            steps = json.load(f)
        col1, col2, col3 = st.columns([4.7, 0.6, 4.7])
        for i, step in enumerate(steps):
            col1, _, _ = st.columns([4.7, 0.6, 4.7])
            with col1:
                st.markdown(f'**Step#{i + 1}**')
            col1, col2, col3 = st.columns([4.7, 0.6, 4.7])
            with col1:
                st.markdown('```sql\n' + step + '\n```')
            with col2:
                st.markdown('$\\rightarrow$')
            with st.spinner('...'):
                time.sleep(1)  # brief pause so the steps render sequentially
            with open("tmp_for_vis/{}_result_step_{}.txt".format(stamp, i), "r") as f:
                result_in_this_step = json.load(f)
            with col3:
                if isinstance(result_in_this_step, dict):
                    rows = result_in_this_step["rows"]
                    header = result_in_this_step["header"]
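                    # Auto-generated 'col_*' headers carry no meaning to the reader;
                    # relabel them with the SQL step that produced the column.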
                    if isinstance(header, list):
                        for idx in range(len(header)):
                            if header[idx].startswith('col_'):
                                header[idx] = step
                    st.dataframe(pd.DataFrame(rows, columns=header), use_container_width=True)
                else:
                    st.markdown(result_in_this_step)
            with st.spinner('...'):
                time.sleep(1)
    elif selected_language == 'Python':
        pass  # no step-by-step visualization for Python programs
    if isinstance(exec_answer, list) and len(exec_answer) == 1:
        exec_answer = exec_answer[0]
    st.text('')
    st.markdown(f"Execution answer: {exec_answer}")
    # todo: remove the tmp_for_vis files after rendering.
except Exception:
    traceback.print_exc()
    st.error("Program execution failed; see the server logs for details.")