import gradio as gr
import spaces
import time
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
from typing import List
# Hub id of the checkpoint served by this demo; processor and weights share it.
MODEL_ID = "TIGER-Lab/Mantis-8B-Idefics2"
# Loaded once at import time (no device given here); generate_stream moves
# the model to CUDA on each request.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

@spaces.GPU
def generate_stream(text:str, images:List[Image.Image], history: List[dict], **kwargs):
    """Stream the model's reply to a chat conversation.

    ``model.generate`` runs on a background thread while this generator
    drains a ``TextIteratorStreamer`` and yields the decoded text.

    Args:
        text: Unused; the prompt is built entirely from ``history``. Kept so
            existing callers keep working.
        images: Images for the ``{"type": "image"}`` entries in ``history``,
            in order. An empty list is passed to the processor as ``None``.
        history: Chat messages in the structure accepted by
            ``processor.apply_chat_template``.
        **kwargs: Extra keyword arguments for ``model.generate``; they
            override same-named processor outputs. A ``streamer`` entry is
            always replaced by this function's own streamer.

    Yields:
        The full decoded reply so far (cumulative, not just the new chunk).

    Raises:
        Whatever ``model.generate`` raised, re-raised here after the stream
        has been closed.
    """
    from threading import Thread

    from transformers import TextIteratorStreamer

    model.to("cuda")
    if not images:
        images = None

    prompt = processor.apply_chat_template(history, add_generation_prompt=True)
    print("Prompt: ")
    print(prompt)
    print("Images: ")
    print(images)
    inputs = processor(text=prompt, images=images, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # Same precedence as before: caller kwargs override processor outputs,
    # and our streamer overrides everything.
    generate_kwargs = {**inputs, **kwargs, "streamer": streamer}

    errors = []

    def _run_generate():
        # generate() only signals the streamer's end on success; if it raises,
        # close the stream ourselves so the consumer loop below cannot block
        # forever, and hand the exception back to the consumer thread.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generate)
    thread.start()
    output = ""
    for _output in streamer:
        output += _output
        yield output
    thread.join()
    if errors:
        raise errors[0]

def enable_next_image(uploaded_images, image):
    """Record ``image`` as uploaded and lock the multimodal input box.

    ``uploaded_images`` is mutated in place (``image`` is appended) and is
    returned together with a cleared, non-interactive ``MultimodalTextbox``.
    """
    uploaded_images.append(image)
    locked_input = gr.MultimodalTextbox(value=None, interactive=False)
    return uploaded_images, locked_input

def add_message(history, message):
    """Append a submitted multimodal message to the chatbot history.

    Each uploaded file becomes its own row ``[(path,), None]``; the text, if
    any, is added last as ``[text, None]``. ``history`` is mutated in place.

    Returns:
        The updated ``history`` and a cleared ``MultimodalTextbox``.
    """
    uploaded = message["files"]
    if uploaded:
        history.extend([[(path,), None] for path in uploaded])
    typed = message["text"]
    if typed:
        history.append([typed, None])
    return history, gr.MultimodalTextbox(value=None)

def print_like_dislike(x: gr.LikeData):
    """Echo a chatbot like/dislike event to stdout.

    Prints the event's index, value and liked flag, space-separated.
    """
    event_fields = [x.index, x.value, x.liked]
    print(*event_fields, sep=" ")


def get_chat_images(history):
    """Load every uploaded image in the chat history, in row order.

    A row is treated as an image upload when its first element is a tuple
    (``add_message`` stores uploads as ``(path,)``); the tuple's first item is
    passed to ``load_image``. All other rows are skipped.
    """
    return [load_image(row[0][0]) for row in history if isinstance(row[0], tuple)]

def _user_message_content(text):
    """Split a user string on ``<image>`` placeholders into chat-template parts.

    Without placeholders the text is kept verbatim as one text part. With
    placeholders, each one becomes ``{"type": "image"}`` and the surrounding
    segments become stripped text parts; empty segments are dropped.
    """
    if "<image>" not in text:
        return [{"type": "text", "text": text}]
    content = []
    for seg_idx, segment in enumerate(text.split("<image>")):
        if seg_idx > 0:
            content.append({"type": "image"})
        if segment.strip():
            content.append({"type": "text", "text": segment.strip()})
    return content


def get_chat_history(history):
    """Convert chatbot rows into chat-template messages plus their images.

    Text rows become ``user`` messages (with ``<image>`` placeholders turned
    into image parts); a truthy reply on the same row becomes an
    ``assistant`` message. Image-upload rows contribute only to the returned
    image list.

    Args:
        history: Chatbot rows ``[user, reply]`` where ``user`` is a string or
            an uploaded-file tuple.

    Returns:
        ``(messages, images)`` where ``images`` comes from ``get_chat_images``.

    Raises:
        gr.Error: If, at any point in the conversation, the running total of
            ``<image>`` placeholders exceeds the number of uploaded images.
    """
    images = get_chat_images(history)
    messages = []
    cur_image_idx = 0
    for message in history:
        user_part, reply = message[0], message[1]
        if not isinstance(user_part, str):
            continue  # image rows are already collected by get_chat_images
        num_images = user_part.count("<image>")
        print(num_images, cur_image_idx, len(images))
        # Explicit check (not assert, which -O strips) against the running
        # total of placeholders consumed so far.
        if num_images + cur_image_idx > len(images):
            raise gr.Error("Number of images uploaded is less than the number of <image> placeholders in the text. Please upload more images.")
        messages.append({"role": "user", "content": _user_message_content(user_part)})
        cur_image_idx += num_images
        if reply:
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": reply}]
                }
            )
    return messages, images


def bot(history):
    """Generate and stream the model's reply to the pending user turn.

    The trailing rows of ``history`` that have no reply yet form the pending
    turn: their text rows are joined with spaces and their uploaded files are
    collected in order. The ``<image>`` placeholder count in that text is made
    to match the number of pending images (missing placeholders are
    prepended, extra ones removed from the end), and the corrected text is
    written back into the latest pending text row. The reply is streamed
    into ``history[-1][1]``.

    Args:
        history: Chatbot rows ``[user, reply]`` where ``user`` is a text
            string or a tuple holding an uploaded file path.

    Yields:
        ``history`` after each streamed chunk.

    Raises:
        gr.Error: If the pending turn contains no text.
    """
    pending_text = ""
    pending_images = []
    last_text_row = None
    # Walk backwards until the most recent row that already has a reply.
    for row_idx in range(len(history) - 1, -1, -1):
        row = history[row_idx]
        if row[1]:
            break
        if isinstance(row[0], str):
            pending_text = row[0] + " " + pending_text
            if last_text_row is None:
                last_text_row = row_idx
        elif isinstance(row[0], tuple):
            pending_images.extend(row[0])
    pending_text = pending_text.strip()
    pending_images = pending_images[::-1]
    if not pending_text:
        raise gr.Error("Please enter a message")

    n_placeholders = pending_text.count("<image>")
    n_images = len(pending_images)
    # Write corrections into the latest *text* row rather than history[-1]:
    # the last row can be an image upload, and overwriting it would silently
    # drop that image from the conversation.
    if n_placeholders < n_images:
        gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
        pending_text = "<image> " * (n_images - n_placeholders) + pending_text
        history[last_text_row][0] = pending_text
    elif n_placeholders > n_images:
        gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
        # Remove the last (n_placeholders - n_images) placeholders.
        pending_text = "".join(pending_text.rsplit("<image>", n_placeholders - n_images))
        history[last_text_row][0] = pending_text

    chat_history, chat_images = get_chat_history(history)

    generation_kwargs = {
        "max_new_tokens": 4096,
        "num_beams": 1,
        "do_sample": False
    }

    response = generate_stream(None, chat_images, chat_history, **generation_kwargs)
    for _output in response:
        history[-1][1] = _output
        time.sleep(0.05)
        yield history


        
def build_demo():
    """Assemble the Mantis chat demo UI.

    Layout, top to bottom: intro Markdown, a chatbot, a multimodal input box
    (image uploads), Send / Clear buttons, example prompts, and a citation.
    Pressing Enter in the input box and clicking Send both run
    ``add_message`` and, if it succeeds, stream a reply with ``bot``.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as demo:

        gr.Markdown(""" # Mantis
Mantis is a multimodal conversational AI model that can chat with users about images and text. It's optimized for multi-image reasoning, where interleaved text and images can be used to generate responses.

### [Paper](https://arxiv.org/abs/2405.01483) | [Github](https://github.com/TIGER-AI-Lab/Mantis) | [Models](https://huggingface.co/collections/TIGER-Lab/mantis-6619b0834594c878cdb1d6e4) | [Dataset](https://huggingface.co/datasets/TIGER-Lab/Mantis-Instruct) | [Website](https://tiger-ai-lab.github.io/Mantis/)            
        """)

        gr.Markdown("""## Chat with Mantis
        Mantis supports interleaved text-image input format, where you can simply use the placeholder `<image>` to indicate the position of uploaded images.
        The model is optimized for multi-image reasoning, while preserving the ability to chat about text and images in a single conversation.
        (The model currently serving is [🤗 TIGER-Lab/Mantis-8B-Idefics2](https://huggingface.co/TIGER-Lab/Mantis-8B-Idefics2))
        """)

        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)

        # Enter: append the message, then stream the reply only if that
        # succeeded. This chain owns the public "bot_response" API name.
        chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
        chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")

        chatbot.like(print_like_dislike, None, None)

        with gr.Row():
            send_button = gr.Button("Send")
            clear_button = gr.ClearButton([chatbot, chat_input])

        # Send: same chain as Enter. api_name=False so this does not register
        # a second endpoint alongside "bot_response".
        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).success(
            bot, chatbot, chatbot, api_name=False
        )

        gr.Examples(
            examples=[
                {
                    "text": "<image> <image> <image> Which image shows a different mood of character from the others?",
                    "files": ["./examples/image12.jpg", "./examples/image13.jpg", "./examples/image14.jpg"]
                },
                {
                    "text": "<image> <image> What's the difference between these two images? Please describe as much as you can.", 
                    "files": ["./examples/image1.jpg", "./examples/image2.jpg"]
                },
                {
                    "text": "<image> <image> Which image shows an older dog?",
                    "files": ["./examples/image8.jpg", "./examples/image9.jpg"]   
                },
                {
                    "text": "Write a description for the given image sequence in a single paragraph, what is happening in this episode?", 
                    "files": ["./examples/image3.jpg", "./examples/image4.jpg", "./examples/image5.jpg", "./examples/image6.jpg", "./examples/image7.jpg"]
                },
                {
                    "text": "<image> <image> How many dices are there in image 1 and image 2 respectively?",
                    "files": ["./examples/image10.jpg", "./examples/image15.jpg"]
                },
            ],
            inputs=[chat_input],
        )

        gr.Markdown("""
## Citation
```
@article{jiang2024mantis,
  title={MANTIS: Interleaved Multi-Image Instruction Tuning},
  author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
  journal={arXiv preprint arXiv:2405.01483},
  year={2024}
}
```""")
    return demo
    

if __name__ == "__main__":
    # Build the Gradio app and launch it when this file is run directly.
    # `demo` is bound at module level here.
    demo = build_demo()
    demo.launch()