# import gradio as gr

# model_name = "models/THUDM/chatglm2-6b-int4"
# gr.load(model_name).lauch()

# %%writefile demo-4bit.py

from textwrap import dedent

# credit to https://github.com/THUDM/ChatGLM2-6B/blob/main/web_demo.py
from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html

from loguru import logger

# Model checkpoint. The full-precision "THUDM/chatglm2-6b" can be used instead
# when enough GPU memory is available; this demo serves the int4 variant.
model_name = "THUDM/chatglm2-6b-int4"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Alternative loading recipes (adjust as needed; on-the-fly quantization
# currently only supports 4/8 bit):
# model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
# model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()

import torch

has_cuda = torch.cuda.is_available()
# has_cuda = False  # uncomment to force CPU inference

if has_cuda:
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()  # ~3.92G
else:
    # fp16 is poorly supported on CPU in PyTorch (many ops have no Half CPU
    # kernel), so run in float32 as the upstream ChatGLM2 README does for CPU.
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float()

# Switch to evaluation mode for inference.
model = model.eval()

_ = """Override Chatbot.postprocess"""

def postprocess(self, y):
    """Replacement for ``gr.Chatbot.postprocess`` that renders pairs via mdtex2html.

    Each ``(message, response)`` pair in *y* is replaced in place by its
    ``mdtex2html.convert`` rendering; ``None`` members stay ``None``.

    Args:
        self: The Chatbot instance (unused).
        y: List of ``(message, response)`` pairs, or ``None``.

    Returns:
        ``[]`` when *y* is ``None``; otherwise the same list object *y*,
        mutated in place.
    """
    if y is None:
        return []

    def _render(text):
        return None if text is None else mdtex2html.convert(text)

    for idx, (user_msg, bot_msg) in enumerate(y):
        y[idx] = (_render(user_msg), _render(bot_msg))
    return y


# Install the mdtex2html-based renderer on every gr.Chatbot instance.
setattr(gr.Chatbot, "postprocess", postprocess)


# Per-character HTML escapes applied to lines inside fenced code blocks so that
# neither the browser nor the Markdown converter interprets them. A single
# str.translate pass guarantees the "&" emitted by other entities is never
# re-escaped, which lets "&" itself be escaped safely.
_CODE_ESCAPES = str.maketrans({
    "&": "&amp;",
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert chat text to HTML for display in the Chatbot widget.

    Adapted from https://github.com/GaiZhenbiao/ChuanhuChatGPT/.

    Empty lines are dropped. A line containing a ``` fence alternately opens a
    ``<pre><code class="language-...">`` block (language taken from the text
    after the last backtick) or closes it. Lines inside a code block are
    HTML-escaped character by character; every line after the first gets a
    ``<br>`` prefix.

    Args:
        text: Raw user input or model output.

    Returns:
        The HTML string (``""`` for empty input).
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            items = line.split("`")
            if fence_count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            # An odd number of fences seen so far means we are inside code.
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
    """Stream a chat reply for *input* into the chatbot display.

    Appends a new ``(prompt_html, "")`` row to *chatbot*, then for every chunk
    produced by ``model.stream_chat`` overwrites that row with the rendered
    response so far.

    Yields:
        ``(chatbot, history, past_key_values)`` after each streamed chunk, with
        history and cache as returned by the model.
    """
    prompt_html = parse_text(input)
    chatbot.append((prompt_html, ""))
    stream = model.stream_chat(
        tokenizer,
        input,
        history,
        past_key_values=past_key_values,
        return_past_key_values=True,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
    )
    for reply, history, past_key_values in stream:
        chatbot[-1] = (prompt_html, parse_text(reply))
        yield chatbot, history, past_key_values


def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
    """Answer a single prompt without chat history (used by the translation API).

    ``model.stream_chat`` produces a stream that must be iterated (as predict
    does) for any generation to happen. This function consumes the stream to
    the end and returns the last response, which is the complete answer.

    Args:
        input: Prompt text sent to the model as-is.
        max_length: ``max_length`` forwarded to ``model.stream_chat``.
        top_p: Nucleus-sampling ``top_p`` forwarded to ``model.stream_chat``.
        temperature: Sampling temperature forwarded to ``model.stream_chat``.

    Returns:
        The final response text, or ``""`` if generation raised. In that case
        the error is logged with its traceback.
    """
    response = ""
    try:
        stream = model.stream_chat(
            tokenizer,
            input,
            history=[],
            past_key_values=None,
            return_past_key_values=False,
            max_length=max_length,
            top_p=top_p,
            temperature=temperature,
        )
        # The first element of each streamed item is the full response so far
        # (predict relies on the same thing), so keep only the latest one.
        for chunk in stream:
            response = chunk[0]
    except Exception:
        # API boundary: log with traceback and return an empty result instead
        # of failing the request.
        logger.exception("trans_api: generation failed")
        return ""
    logger.debug(f"{response=}")
    return response
        

def reset_user_input():
    """Return a Gradio update that empties the prompt textbox."""
    cleared = gr.update(value='')
    return cleared


def reset_state():
    """Return fresh (chatbot, history, past_key_values) values for a cleared chat."""
    empty_chatbot = []
    empty_history = []
    return empty_chatbot, empty_history, None


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("""<h1 align="center">ChatGLM2-6B-int4</h1>""")
    with gr.Accordion("Info", open=False):
        _ = """
            A query takes from 30 seconds to a few tens of seconds, dependent on the number of words/characters 
            the question and answer contain.

            * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
            
            * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4

            * Top P controls dynamic vocabulary selection based on context. 

            For a table of example values for different scenarios, refer to [this](https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683)

            If the instance is not on a GPU (T4), it will be very slow. You can try to run the colab notebook [chatglm2-6b-4bit colab notebook](https://colab.research.google.com/drive/1WkF7kOjVCcBBatDHjaGkuJHnPdMWNtbW?usp=sharing) for a spin.

            The T4 GPU is sponsored by a community GPU grant from Huggingface. Thanks a lot!
            """
        gr.Markdown(dedent(_))
    # Chat area: conversation display, prompt box and submit button on the
    # left, history reset and generation parameters on the right.
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 32768, value=4096, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    # Per-session conversation state, fed back into predict on every turn.
    history = gr.State([])
    past_key_values = gr.State(None)

    # predict streams updates to the chat display and state; the second
    # handler clears the prompt box.
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
                    [chatbot, history, past_key_values], show_progress=True, api_name="predict")
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)

    # Single-turn, history-free endpoint exposed as api_name="tr". Its return
    # value is routed to an output box. With no outputs, neither the UI nor
    # API clients would receive the result.
    with gr.Accordion("For Translation API", open=False):
        input_text = gr.Text()
        output_text = gr.Text(label="Result")
        tr_btn = gr.Button("Go", variant="primary")
    tr_btn.click(trans_api, [input_text, max_length, top_p, temperature], [output_text], show_progress=True, api_name="tr")
        
# Start the app with Gradio's request queue enabled and debug output on.
# For a public share link or to open a local browser tab, pass share=True
# and/or inbrowser=True to launch().
queued_demo = demo.queue()
queued_demo.launch(debug=True)