import os
import tempfile
import uuid
import torch
from PIL import Image
from torchvision import transforms
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler

# -------------------------------------------------------------------
# Helper: Resize & center-crop to a fixed square
# -------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """Scale *img* so its shorter side equals *size*, then center-crop a square.

    Args:
        img: input image; its mode is preserved.
        size: edge length in pixels of the returned square; must be positive.

    Returns:
        A new ``size`` x ``size`` image (Lanczos resampling).

    Raises:
        ValueError: if ``size`` is not positive or ``img`` has a zero-sized
            dimension.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    w, h = img.size
    if w <= 0 or h <= 0:
        raise ValueError(f"cannot resize an empty image of size {img.size}")

    # Pin the short side to exactly `size` and round the long side.
    # Computing int(side * (size / short_side)) can truncate to size - 1 via
    # float error. For example, a short side of 49 with size 512 gives
    # 511.99999999999994 -> 511. The crop below would then extend past the
    # image edge, and PIL fills that strip with zero-valued (black) pixels.
    if w <= h:
        new_w, new_h = size, max(size, round(h * size / w))
    else:
        new_w, new_h = max(size, round(w * size / h)), size

    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))


# -------------------------------------------------------------------
# Helper: Generate a single VLM prompt for recursive_multiscale
# -------------------------------------------------------------------
def _generate_vlm_prompt(
    vlm_model: Qwen2_5_VLForConditionalGeneration,
    vlm_processor: AutoProcessor,
    process_vision_info,
    prev_pil: Image.Image,
    zoomed_pil: Image.Image,
    device: str = "cuda"
) -> str:
    """Ask the VLM to describe a zoomed-in view of an image as a set of words.

    Args:
        vlm_model: loaded Qwen2.5-VL generation model.
        vlm_processor: the matching processor (chat template, tokenizer and
            image preprocessing).
        process_vision_info: callable that takes the chat message list and
            returns ``(image_inputs, video_inputs)`` for ``vlm_processor``.
        prev_pil: the full image from the previous recursion step (sent first).
        zoomed_pil: the cropped-and-upsampled view for this step (sent second).
        device: device the processed inputs are moved to before generation.

    Returns:
        The text produced after the prompt (at most 128 new tokens), stripped
        of surrounding whitespace.
    """
    instruction = (
        "The second image is a zoom-in of the first image. "
        "Based on this knowledge, what is in the second image? "
        "Give me a set of words."
    )

    # The instruction goes in the system turn; the user turn carries only the
    # two images, in order (full image first, zoomed view second).
    conversation = [
        {"role": "system", "content": instruction},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": picture}
                for picture in (prev_pil, zoomed_pil)
            ],
        },
    ]

    # Render the chat to a prompt string (untokenized). add_generation_prompt
    # asks the template to open the assistant turn for the model to continue.
    prompt_text = vlm_processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(conversation)
    batch = vlm_processor(
        text=[prompt_text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    ).to(device)

    output_ids = vlm_model.generate(**batch, max_new_tokens=128)

    # Drop the echoed prompt tokens so only the newly generated ones are decoded.
    new_tokens = []
    for prompt_ids, full_ids in zip(batch.input_ids, output_ids):
        new_tokens.append(full_ids[len(prompt_ids):])

    decoded = vlm_processor.batch_decode(
        new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0].strip()



# -------------------------------------------------------------------
# Main Function: recursive_multiscale_sr (with multiple centers)
# -------------------------------------------------------------------
# Model checkpoints / identifiers used by recursive_multiscale_sr
# (hard-coded to the example setup).
_SR_LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
_SR_VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
_SD3_MODEL = "stabilityai/stable-diffusion-3-medium-diffusers"
_VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"


def _validate_centers(centers, rec_num: int) -> list:
    """Return one normalized (x, y) crop center per recursion step.

    ``None`` expands to ``(0.5, 0.5)`` for every step.

    Raises:
        ValueError: if ``centers`` is not a list/tuple of exactly ``rec_num``
            two-element pairs.
    """
    if centers is None:
        return [(0.5, 0.5)] * rec_num
    if not isinstance(centers, (list, tuple)):
        # Report the type instead of calling len(). For unsized inputs such
        # as a generator, len() would raise its own TypeError.
        raise ValueError(
            f"`centers` must be a list of {rec_num} (x,y) tuples, "
            f"but got {type(centers).__name__}."
        )
    if len(centers) != rec_num:
        raise ValueError(
            f"`centers` must be a list of {rec_num} (x,y) tuples, but got length {len(centers)}."
        )
    for i, center in enumerate(centers):
        try:
            _cx, _cy = center
        except (TypeError, ValueError):
            raise ValueError(
                f"`centers[{i}]` must be an (x, y) pair, but got {center!r}."
            ) from None
    return list(centers)


def _zoom_crop(img: Image.Image, upscale: int, center) -> Image.Image:
    """Crop a (w // upscale, h // upscale) window around *center* and resize it back to w x h.

    *center* is a normalized (x, y) pair. The window is shifted to stay inside
    the image, so centers near or beyond the border are effectively clamped.
    """
    w, h = img.size
    crop_w, crop_h = w // upscale, h // upscale
    cx_norm, cy_norm = center
    cx, cy = int(cx_norm * w), int(cy_norm * h)
    left = max(0, min(cx - crop_w // 2, w - crop_w))
    top = max(0, min(cy - crop_h // 2, h - crop_h))
    window = img.crop((left, top, left + crop_w, top + crop_h))
    return window.resize((w, h), Image.BICUBIC)


def _load_sr_model(upscale: int, device: str):
    """Build the frozen SD3 backbone and wrap it in ``OSEDiff_SD3_TEST``."""
    from types import SimpleNamespace

    # OSEDiff_SD3_TEST reads its configuration from an argparse-style object.
    # Only the fields below are provided; other CLI flags were noted as unused.
    args = SimpleNamespace(
        upscale=upscale,
        lora_path=_SR_LORA_PATH,
        vae_path=_SR_VAE_PATH,
        pretrained_model_name_or_path=_SD3_MODEL,
        merge_and_unload_lora=False,
        lora_rank=4,
        vae_decoder_tiled_size=224,
        vae_encoder_tiled_size=1024,
        latent_tiled_size=96,
        latent_tiled_overlap=32,
        # NOTE(review): says fp16, yet the transformer/VAE are cast to float32
        # below — confirm which precision OSEDiff_SD3_TEST actually uses.
        mixed_precision="fp16",
        efficient_memory=False,
    )

    sd3 = SD3Euler()
    for text_encoder in (sd3.text_enc_1, sd3.text_enc_2, sd3.text_enc_3):
        text_encoder.to(device)
    sd3.transformer.to(device, dtype=torch.float32)
    sd3.vae.to(device, dtype=torch.float32)
    # Inference only: freeze every sub-module.
    for module in (
        sd3.text_enc_1,
        sd3.text_enc_2,
        sd3.text_enc_3,
        sd3.transformer,
        sd3.vae,
    ):
        module.requires_grad_(False)

    return OSEDiff_SD3_TEST(args, sd3)


def _load_vlm(name: str = _VLM_NAME):
    """Load the Qwen2.5-VL model (auto-dispatched across devices) and its processor."""
    vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        name,
        torch_dtype="auto",
        device_map="auto",
    )
    vlm_processor = AutoProcessor.from_pretrained(name)
    return vlm_model, vlm_processor


def recursive_multiscale_sr(
    input_png_path: str,
    upscale: int,
    rec_num: int = 4,
    centers: list[tuple[float, float]] = None,
) -> tuple[list[Image.Image], list[str]]:
    """Run ``rec_num`` zoom-and-super-resolve steps on a single image.

    Each step crops a window of size 1/``upscale`` around the step's center
    from the current image and upsamples it bicubically back to full size. A
    VLM prompt is generated from (current image, zoomed view), and the zoomed
    view is super-resolved with that prompt. The SR output becomes the current
    image for the next step.

    Args:
        input_png_path: path to the input image on disk. It is converted to
            RGB and resized/center-cropped to 512x512.
        upscale: zoom factor per recursion (e.g. 4). Must satisfy
            ``1 <= upscale <= 512``.
        rec_num: number of recursion steps.
        centers: ``rec_num`` normalized (x, y) pairs, nominally in [0, 1],
            giving the crop center for each step. The crop window is clamped
            to the image. ``None`` means (0.5, 0.5) for every step.

    Returns:
        ``(sr_pil_list, prompt_list)``: the SR output of each step, in order,
        and the VLM prompt used at each step.

    Raises:
        ValueError: for an out-of-range ``upscale`` or malformed ``centers``.
            These are checked before any model is loaded.
    """
    device = "cuda"
    process_size = 512

    # Validate cheap inputs and read the image before loading multi-GB models,
    # so bad arguments or a bad path fail fast.
    if not 1 <= upscale <= process_size:
        raise ValueError(
            f"`upscale` must be in [1, {process_size}], but got {upscale}."
        )
    centers = _validate_centers(centers, rec_num)

    with Image.open(input_png_path) as src:
        prev_pil = resize_and_center_crop(src.convert("RGB"), process_size)

    model_test = _load_sr_model(upscale, device)
    vlm_model, vlm_processor = _load_vlm()

    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()

    sr_pil_list: list[Image.Image] = []
    prompt_list: list[str] = []

    for center in centers:
        zoomed_pil = _zoom_crop(prev_pil, upscale, center)

        prompt_tag = _generate_vlm_prompt(
            vlm_model=vlm_model,
            vlm_processor=vlm_processor,
            process_vision_info=process_vision_info,
            prev_pil=prev_pil,
            zoomed_pil=zoomed_pil,
            device=device,
        )

        # ToTensor yields values in [0, 1]; map them to [-1, 1] for the SR model.
        lq = to_tensor(zoomed_pil).unsqueeze(0).to(device) * 2.0 - 1.0

        with torch.no_grad():
            out_tensor = model_test(lq, prompt=prompt_tag)[0]
            out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
            # Map [-1, 1] back to [0, 1] before converting to PIL.
            out_pil = to_pil(out_tensor * 0.5 + 0.5)

        sr_pil_list.append(out_pil)
        prompt_list.append(prompt_tag)
        # The SR output is the "full" image for the next zoom step.
        prev_pil = out_pil

    return sr_pil_list, prompt_list