from typing import Optional, Tuple, Union, List

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.utils import (
    logging,
)
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from transformers.models.blip_2.modeling_blip_2 import Blip2ForConditionalGenerationModelOutput
from transformers import (
    Blip2PreTrainedModel,
    Blip2VisionModel,
    Blip2QFormerModel,
    PreTrainedTokenizer,
    PreTrainedModel,
)


# Module-level logger obtained through transformers' logging utilities; used by
# `ZiyaBlip2ForCausalLM._preprocess_accelerate` to emit its multi-GPU warning.
logger = logging.get_logger(__name__)


class ZiyaBlip2ForCausalLM(Blip2PreTrainedModel):
    """BLIP-2 style vision-language model whose language model is supplied externally.

    The vision encoder, the learned query tokens, the Q-Former and the linear projection
    into the language model's embedding space are built from ``config``. The language model
    itself is passed to ``__init__`` (its keys may be missing from this model's checkpoint,
    see ``_keys_to_ignore_on_load_missing``). The projected query outputs are spliced between
    the token embeddings of the text before the image and the text after the image.
    """

    config_class = Blip2Config
    main_input_name = "pixel_values"
    # The language model is injected through ``__init__``, so missing ``language_model``
    # keys are tolerated when loading a checkpoint of this class.
    _keys_to_ignore_on_load_missing = [
        r"language_model",
    ]

    def __init__(self, config: Blip2Config, language_model: Optional[PreTrainedModel] = None):
        """Build the vision / Q-Former stack from ``config`` and attach ``language_model``.

        Args:
            config: BLIP-2 configuration; ``vision_config``, ``qformer_config``,
                ``text_config`` and ``num_query_tokens`` size the sub-modules.
            language_model: Language model that consumes the projected image features.
                Most methods below dereference it, so it must be set before they are used.
        """
        super().__init__(config)

        self.vision_model = Blip2VisionModel(config.vision_config)

        # Learned queries of shape (1, num_query_tokens, qformer_hidden_size);
        # expanded along the batch dimension at run time.
        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
        )
        self.qformer = Blip2QFormerModel(config.qformer_config)

        # Maps Q-Former outputs to the language model's hidden size.
        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size, config.text_config.hidden_size
        )
        self.language_model = language_model

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the language model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the language model's input embedding module."""
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        """Replace the language model's output embedding module."""
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        """Return the language model's output embedding module."""
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        """Delegate to the language model's ``get_encoder``."""
        return self.language_model.get_encoder()

    def get_decoder(self):
        """Delegate to the language model's ``get_decoder``."""
        return self.language_model.get_decoder()

    def _tie_weights(self):
        """For encoder-decoder language models, share ``language_model.shared`` as the
        token embedding of both the encoder and the decoder. No-op for decoder-only models."""
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
            # The sentences are implicitly concatenated into ONE message: passing a second
            # positional string would make `logging` try `msg % args` and fail.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models."
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids_before_image: torch.LongTensor,
        input_ids_after_image: torch.LongTensor,
        labels_after_image: torch.LongTensor,
        # Labels never occur before the image, so there is no `labels_before_image`;
        # the positions of `input_ids_before_image` are padded with -100 instead.
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        """Run images + surrounding text through the model and compute the LM loss.

        The language model receives ``[embed(before) | projected queries | embed(after)]``
        with an all-ones attention mask, and labels that are -100 everywhere except the
        ``labels_after_image`` part.

        Args:
            pixel_values: Images as produced by the image processor.
            input_ids_before_image: Token ids placed before the image queries.
            input_ids_after_image: Token ids placed after the image queries; its batch size
                must equal the image batch size.
            labels_after_image: Labels aligned with ``input_ids_after_image``.
            output_attentions / output_hidden_states / return_dict: Forwarded to sub-models.

        Returns:
            A ``Blip2ForConditionalGenerationModelOutput`` when ``return_dict`` is truthy,
            otherwise ``(loss?, logits, vision_outputs, query_outputs, lm_outputs)``.

        Raises:
            NotImplementedError: If the configured language model is not decoder-only.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Step 1: encode the images with the vision encoder.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # Step 2: the query tokens cross-attend to the image embeddings in the Q-Former.
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
        )
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # Step 2.5: assemble the language-model inputs around the projected queries.
        language_model_inputs = self.language_projection(query_output)
        # Every piece is gathered on this device so the concatenations below are valid.
        target_device = language_model_inputs.device
        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=target_device
        )
        # The image batch must match the text batch.
        assert language_model_inputs.shape[0] == input_ids_after_image.shape[0]

        embed_tokens = self.language_model.get_input_embeddings()
        inputs_embeds = torch.cat(
            [
                embed_tokens(input_ids_before_image).to(target_device),
                language_model_inputs,
                embed_tokens(input_ids_after_image).to(target_device),
            ],
            dim=1,
        )

        attention_mask = torch.cat(
            [
                torch.ones_like(input_ids_before_image).to(target_device),
                language_model_attention_mask,
                torch.ones_like(input_ids_after_image).to(target_device),
            ],
            dim=1,
        )

        # Labels: -100 (ignored by the loss) for the prefix and the query positions,
        # followed by the caller's labels for the text after the image.
        labels = torch.cat(
            [
                torch.full(input_ids_before_image.shape, -100, dtype=torch.long, device=target_device),
                torch.full(language_model_inputs.shape[:-1], -100, dtype=torch.long, device=target_device),
                # Must live on the same device as the pieces above for torch.cat.
                labels_after_image.to(target_device),
            ],
            dim=1,
        )

        # Step 3: run the language model (only decoder-only models are supported).
        if not self.config.use_decoder_only_language_model:
            raise NotImplementedError("not impl")

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
        )
        loss = outputs.loss if return_dict else outputs[0]
        logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    def prepare_inputs_for_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        pixel_values: torch.Tensor,
        previous_querys: List[str],
        previous_outputs: List[str],
        max_length: int,
    ):
        """Build ``(input_ids, inputs_embeds)`` for one chat turn.

        Layout: ``[prompt prefix | image query embeddings | dialogue history + query]``.
        The image positions in ``input_ids`` are filled with ``tokenizer.eos_token_id``
        placeholders so that ``input_ids`` and ``inputs_embeds`` have the same length.
        When the sequence is longer than ``max_length``, both are truncated from the left
        (keeping the most recent ``max_length`` positions).

        Args:
            tokenizer: Tokenizer matching the language model.
            query: Current user query.
            pixel_values: Image tensor from the image processor. The prefix is tokenized with
                batch size 1, so presumably a single image is expected — TODO confirm.
            previous_querys: Earlier user queries (same length as ``previous_outputs``).
            previous_outputs: Earlier assistant replies.
            max_length: Maximum number of positions kept.

        Returns:
            Tuple ``(input_ids, inputs_embeds)`` of equal sequence length.
        """
        assert len(previous_querys) == len(previous_outputs)
        device = self.device
        prefix = self.config.prompt_prefix
        human_name = self.config.human_name
        assistant_name = self.config.assistant_name

        # 1. Token ids. The prefix gets the tokenizer's special tokens; dialogue turns do not.
        input_ids_before_image = tokenizer(prefix, return_tensors="pt").input_ids.to(device)
        dialogue_ids: List[int] = []
        for past_query, past_output in zip(previous_querys, previous_outputs):
            # "{human_name}: {query}\n" followed by "{assistant_name}: {output}\n"
            dialogue_ids += tokenizer(f"{human_name}: {past_query}\n", add_special_tokens=False).input_ids
            dialogue_ids += tokenizer(f"{assistant_name}: {past_output}\n", add_special_tokens=False).input_ids

        dialogue_ids += tokenizer(f"{human_name}: {query}\n", add_special_tokens=False).input_ids
        # NOTE(review): history turns use "{assistant_name}: " but the final cue is
        # "{assistant_name} :" (space before the colon) — confirm this matches training data.
        dialogue_ids += tokenizer(f"{assistant_name} :", add_special_tokens=False).input_ids
        inputs_ids_after_image = torch.tensor([dialogue_ids], dtype=torch.long, device=device)

        # 2. Image -> vision encoder -> Q-Former -> projection into LM embedding space.
        # `.to()` is not in-place; the result must be kept.
        pixel_values = pixel_values.to(device)
        image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
        )
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state
        language_model_inputs = self.language_projection(query_output)

        # 3. Concatenate prefix embeddings, image query embeddings and dialogue embeddings.
        embed_tokens = self.get_input_embeddings()
        inputs_embeds = torch.cat(
            [
                embed_tokens(input_ids_before_image).to(language_model_inputs.device),
                language_model_inputs,
                embed_tokens(inputs_ids_after_image).to(language_model_inputs.device),
            ],
            dim=1,
        )

        # Matching ids: eos placeholders stand in for the image query positions. Built on
        # the same device as the other id tensors so the concatenation is valid.
        input_ids = torch.cat(
            [
                input_ids_before_image,
                torch.full(
                    language_model_inputs.shape[:-1],
                    tokenizer.eos_token_id,
                    dtype=torch.long,
                    device=input_ids_before_image.device,
                ),
                inputs_ids_after_image,
            ],
            dim=1,
        )

        # Truncate ids and embeddings together so callers deriving an attention mask from
        # `input_ids` stay aligned with `inputs_embeds`.
        if inputs_embeds.shape[1] > max_length:
            inputs_embeds = inputs_embeds[:, -max_length:, :]
            input_ids = input_ids[:, -max_length:]

        return input_ids, inputs_embeds

    def chat(self,
             tokenizer,
             query: str,
             pixel_values: torch.Tensor,
             previous_querys: List[str],
             previous_outputs: List[str],
             max_input_length: int = 2048,
             **generate_kwargs,):
        """
        Generate a reply in chat style.

        Args:
            tokenizer (PreTrainedTokenizer): tokenizer matching the language model.
            query (str): current user query.
            pixel_values (torch.Tensor): image after the image processor.
            previous_querys (List[str]): earlier user queries (chat history).
            previous_outputs (List[str]): earlier assistant replies (chat history).
            max_input_length (int): maximum number of input positions (prefix + image +
                dialogue) passed to the language model; older positions are dropped first.
            **generate_kwargs: forwarded to ``self.language_model.generate``.

        Returns:
            str: the decoded text of the first generated sequence (special tokens skipped).
        """
        input_ids, inputs_embeds = self.prepare_inputs_for_chat(
            tokenizer, query, pixel_values, previous_querys, previous_outputs, max_input_length
        )
        response = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=torch.ones_like(input_ids),
            **generate_kwargs,
        )
        response = tokenizer.decode(response[0], skip_special_tokens=True)
        return response