import time
import torch
import cv2
from PIL import Image, ImageDraw, ImageOps
import numpy as np
from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL

class BaseSegmenter:
    def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True, model=None):
        print(f"Initializing BaseSegmenter to {device}")
        self.device = device
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.processor = None
        self.model_type = model_type
        if model is None:
            self.checkpoint = checkpoint
            self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
            self.model.to(device=self.device)
        else:
            self.model = model
        self.reuse_feature = reuse_feature
        self.predictor = SamPredictor(self.model)
        self.mask_generator = SamAutomaticMaskGenerator(self.model)
        self.image_embedding = None
        self.image = None

    
    @torch.no_grad()
    def set_image(self, image: Union[np.ndarray, Image.Image, str]):
        if type(image) == str: # input path
            image = Image.open(image)
            image = np.array(image)
        elif type(image) == Image.Image:
            image = np.array(image)
        self.image = image
        if self.reuse_feature:
            self.predictor.set_image(image)
            self.image_embedding = self.predictor.get_image_embedding()
            print(self.image_embedding.shape)

    
    @torch.no_grad()
    def inference(self, image, control):
        if 'everything' in control['prompt_type']:
            masks = self.mask_generator.generate(image)
            new_masks = np.concatenate([mask["segmentation"][np.newaxis,:] for mask in masks])
            return new_masks
        else:
            if not self.reuse_feature or self.image_embedding is None:
                self.set_image(image)
                self.predictor.set_image(self.image)
            else:
                assert self.image_embedding is not None
                self.predictor.features = self.image_embedding
      
        if 'mutimask_output' in control:
            masks, scores, logits = self.predictor.predict(
                point_coords = np.array(control['input_point']),
                point_labels = np.array(control['input_label']),
                multimask_output = True,
            )
        elif 'input_boxes' in control:
            transformed_boxes = self.predictor.transform.apply_boxes_torch(
                torch.tensor(control["input_boxes"], device=self.predictor.device),
                image.shape[:2]
            )
            masks, _, _ = self.predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            masks = masks.squeeze(1).cpu().numpy()
            
        else:
            input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
            input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
            input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
           
            masks, scores, logits = self.predictor.predict(
                point_coords = input_point,
                point_labels = input_label,
                box = input_box,
                multimask_output = False,
            )
            
            if 0 in control['input_label']:
                mask_input = logits[np.argmax(scores), :, :]
                masks, scores, logits = self.predictor.predict(
                    point_coords=input_point,
                    point_labels=input_label,
                    box = input_box,
                    mask_input=mask_input[None, :, :],
                    multimask_output=False,
                )
  
        return masks

if __name__ == "__main__":
    image_path = 'segmenter/images/truck.jpg'
    prompts = [
        # {
        #     "prompt_type":["click"],
        #     "input_point":[[500, 375]],
        #     "input_label":[1],
        #     "multimask_output":"True",
        # },
        {
            "prompt_type":["click"],
            "input_point":[[1000, 600], [1325, 625]],
            "input_label":[1, 0],
        },
        # {
        #     "prompt_type":["click", "box"],
        #     "input_box":[425, 600, 700, 875],
        #     "input_point":[[575, 750]],
        #     "input_label": [0]
        # },
        # {
        #     "prompt_type":["box"],
        #     "input_boxes": [
        #         [75, 275, 1725, 850],
        #         [425, 600, 700, 875],
        #         [1375, 550, 1650, 800],
        #         [1240, 675, 1400, 750],
        #     ]
        # },
        # {
        #     "prompt_type":["everything"]
        # },
    ]
    
    init_time = time.time()
    segmenter = BaseSegmenter(
        device='cuda',
        # checkpoint='sam_vit_h_4b8939.pth',
        checkpoint='segmenter/sam_vit_h_4b8939.pth',
        model_type='vit_h',
        reuse_feature=True
    )
    print(f'init time: {time.time() - init_time}')
    
    image_path = 'test_img/img2.jpg'
    infer_time = time.time()
    for i, prompt in enumerate(prompts):
        print(f'{prompt["prompt_type"]} mode')
        image = Image.open(image_path)
        segmenter.set_image(np.array(image))
        masks = segmenter.inference(np.array(image), prompt)
        Image.fromarray(masks[0]).save('seg.png')
        print(masks.shape)
        
    print(f'infer time: {time.time() - infer_time}')