import gradio as gr
from modeling import global_config, ToyTransformer, AttentionBackend
import torch
from tokenizers import TRIETokenizer
from threading import Thread
import bisect

# Run on the GPU when one is available, otherwise fall back to the CPU.
# (Both branches previously selected the CPU, silently ignoring any GPU.)
g_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Select the Naive attention backend from `modeling`
# (presumably the portable reference implementation — confirm).
global_config['attn_backend'] = AttentionBackend.Naive

# Model hyper-parameters. They must describe the same architecture that
# produced model.pt, otherwise load_state_dict below rejects the checkpoint.
g_SEQ_LEN = 1024  # context length given to the model; predict() budgets prompt + reply within it
g_HIDDEN_SIZE = 768
g_NUM_HEADS = 12
g_NUM_LAYERS = 12
g_DTYPE = torch.float32

g_tokenizer = TRIETokenizer('llama_vocab_pruned_32k.json')
g_model = ToyTransformer(g_tokenizer.get_vocab_size(), g_NUM_LAYERS, g_NUM_HEADS, g_HIDDEN_SIZE, g_SEQ_LEN, g_device, g_DTYPE)

# Deserialize on the CPU; load_state_dict then copies each tensor into the
# model's parameters, which already live on g_device.
g_model.load_state_dict(torch.load('model.pt', map_location='cpu'))


def generate(model, tokenizer, prompt, temperature, top_p, rep_penalty,
             max_new_tokens=20, total_tokens=None,
             end_tokens=None,
             enable_kv_cache=True):
    """Sample tokens from ``model`` one at a time, yielding each as it is drawn.

    Every step runs a forward pass and takes the logits of the last position.
    It then applies a repetition penalty to every token already in the
    sequence (prompt included) and divides by ``temperature``. Finally it
    restricts the softmax distribution to a top-p nucleus and draws one token
    with ``torch.multinomial``.

    Args:
        model: network exposing ``eval()``, ``device`` and
            ``forward(input_ids, kv_cache=...) -> (logits, kv_cache)``; logits
            are indexed as ``logits[0][-1]`` (presumably (batch, seq, vocab)).
        tokenizer: only used to ``encode`` ``prompt`` when it is a ``str``.
        prompt: text, or an already-encoded list of token ids (not mutated).
        temperature: logits are divided by ``max(temperature, 1e-6)``.
        top_p: nucleus threshold. All sorted tokens whose cumulative
            probability is <= ``top_p`` are kept, plus one more.
        rep_penalty: for previously seen tokens, negative logits are multiplied
            by it and non-negative logits divided by it.
        max_new_tokens: maximum number of tokens to generate.
        total_tokens: if given, replaces ``max_new_tokens`` with
            ``max(0, total_tokens - len(prompt_tokens))``.
        end_tokens: container of ids that stop generation; the stop token
            itself is not yielded.
        enable_kv_cache: if True, only the newest token is fed each step and the
            cache returned by the model is passed back in; otherwise the full
            sequence is re-fed every step with no cache.

    Yields:
        A one-element list containing the newly sampled token id.
    """
    model.eval()

    if isinstance(prompt, str):
        pending = tokenizer.encode(prompt)
    else:
        pending = prompt
    history = pending.copy()

    steps = max_new_tokens
    if total_tokens is not None:
        steps = max(0, total_tokens - len(pending))

    with torch.no_grad():
        cache = None
        for _ in range(steps):
            step_input = pending if enable_kv_cache else history
            output, cache = model.forward(
                torch.tensor([step_input]).to(model.device),
                kv_cache=cache)
            scores = output[0][-1].cpu()
            if not enable_kv_cache:
                cache = None

            # Repetition penalty: move the logit of every seen token towards zero.
            seen = torch.tensor(history)
            penalised = torch.gather(scores, 0, seen)
            penalised = torch.where(penalised < 0, penalised * rep_penalty, penalised / rep_penalty)
            scores.scatter_(0, seen, penalised)

            # Temperature, clamped so that 0 does not divide by zero.
            scores /= max(temperature, 1e-6)
            distribution = torch.softmax(scores, dim=0)

            # Nucleus: keep every token whose running total is <= top_p, plus the next one.
            sorted_probs, sorted_ids = torch.sort(distribution, descending=True)
            running_total = torch.cumsum(sorted_probs, dim=0).tolist()
            keep = bisect.bisect_right(running_total, top_p) + 1
            nucleus_probs, nucleus_ids = sorted_probs[:keep], sorted_ids[:keep]
            choice = torch.multinomial(nucleus_probs, num_samples=1).item()
            token = nucleus_ids[choice].item()

            history.append(token)
            pending = [token]

            if end_tokens is not None and token in end_tokens:
                break

            yield pending


def _build_prompt_tokens(history, max_length):
    """Encode the system prefix plus as many of the newest rounds of ``history`` as fit.

    Every token that ends up in the prompt counts against the budget: the system
    prefix, the two ``'\\n'`` separators per round, and both turns of each round.
    The result therefore satisfies ``len(prompt) + max_length < g_SEQ_LEN``.
    Rounds are taken newest-first and the selection stops at the first round
    that does not fit, so the retained history is always a contiguous suffix.

    NOTE(review): if even the newest round does not fit, no rounds are kept and
    the prompt is just the system prefix.
    """
    prefix = g_tokenizer.encode('<s>A chat between User and Assistant.')
    newline = g_tokenizer.encode('\n')
    budget = g_SEQ_LEN - max_length - len(prefix)

    selected, used = [], 0
    for user_msg, bot_msg in reversed(history):
        user_tokens = g_tokenizer.encode('User:' + (user_msg or ''))
        bot_tokens = g_tokenizer.encode('Assistant:' + (bot_msg or ''))
        cost = 2 * len(newline) + len(user_tokens) + len(bot_tokens)
        if used + cost >= budget:
            break
        selected.append((user_tokens, bot_tokens))
        used += cost

    tokens = list(prefix)
    for user_tokens, bot_tokens in reversed(selected):
        tokens += newline + user_tokens + newline + bot_tokens
    return tokens


def predict(user_input, history, max_length, top_p, temperature, rep_penalty, retry):
    """Gradio streaming handler: append a user turn and stream the assistant's reply.

    Args:
        user_input: text typed by the user (ignored when ``retry`` is True).
        history: chat history as a list of ``(user, assistant)`` pairs.
        max_length: maximum number of tokens to generate for the reply.
        top_p, temperature, rep_penalty: sampling settings forwarded to ``generate``.
        retry: if True, drop the last round and regenerate a reply to its user text.

    Yields:
        The updated history: once immediately with an empty reply, then once per
        generated token. When retrying on an empty history, a single ``[]``.
    """
    history = list(history or [])
    if retry:
        if not history:
            yield []
            return
        user_input = history[-1][0]
        history = history[:-1]

    history.append((user_input, ""))
    # Show the user's turn right away. Otherwise it never appears when generation
    # yields nothing (max_length == 0 or an immediate end token).
    yield history

    max_length = int(max_length)  # slider values may arrive as floats; range() needs an int
    input_tokens = _build_prompt_tokens(history, max_length)
    for response in generate(g_model, g_tokenizer, input_tokens, temperature, top_p, rep_penalty, max_length,
                             end_tokens=g_tokenizer.encode('</s>')):
        history[-1] = (history[-1][0], history[-1][1] + g_tokenizer.decode(response))
        yield history


def main():
    """Build the ToyTransformer chat UI and serve it locally.

    The layout is a chatbot pane, a single-line input box, four buttons (Send,
    Retry, Undo, Clear) and four sampling sliders. Send and Retry stream
    ``predict`` output into the chatbot; Send also empties the input box. Undo
    drops the last round; Clear empties the history. The app is queued and
    launched without a public share link, opening a browser tab.
    """
    css = '''
        .contain {max-width:50}

        #chatbot {min-height:500px}
    '''

    with gr.Blocks(css=css) as demo:
        gr.HTML('<h1 align="center">ToyTransformer</h1><h5 align="center">(Note: Please refresh if the page is not responsive.)</h5>')

        chat_view = gr.Chatbot(elem_id='chatbot')
        with gr.Column():
            prompt_box = gr.Textbox(show_label=False, placeholder="Input", lines=1, container=False)
            with gr.Row():
                send_button = gr.Button("Send", variant="primary")
                retry_button = gr.Button("Retry")
                undo_button = gr.Button('Undo')
                clear_button = gr.Button("Clear")
            with gr.Row():
                length_slider = gr.Slider(0, 512, value=200, step=1, label="Max Response Tokens", interactive=True)
                top_p_slider = gr.Slider(0, 1, value=0.7, step=0.01, label="Top-P", interactive=True)
                temperature_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Temperature", interactive=True)
                penalty_slider = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label='Repetition Penalty', interactive=True)

        # Both predict listeners share these leading inputs; only the trailing retry flag differs.
        shared_inputs = [prompt_box, chat_view, length_slider, top_p_slider, temperature_slider, penalty_slider]
        quiet = dict(show_progress=False)

        send_button.click(predict, shared_inputs + [gr.State(False)], [chat_view], **quiet)
        send_button.click(lambda: '', [], [prompt_box], **quiet)

        retry_button.click(predict, shared_inputs + [gr.State(True)], [chat_view], **quiet)

        undo_button.click(lambda rounds: rounds[:-1], [chat_view], [chat_view], **quiet)

        clear_button.click(lambda: [], outputs=[chat_view], **quiet)

    queued_app = demo.queue()
    queued_app.launch(share=False, inbrowser=True)


# Launch the UI only when run as a script, so importing this module does not
# start a blocking web server.
if __name__ == '__main__':
    main()