import torch
from basicsr.utils import img2tensor, tensor2img
from pytorch_lightning import seed_everything
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter, Adapter_light, StyleAdapter
from ldm.util import instantiate_from_config
from ldm.modules.structure_condition.model_edge import pidinet
from ldm.modules.structure_condition.model_seg import seger, Colorize
from ldm.modules.structure_condition.midas.api import MiDaSInference
import gradio as gr
from omegaconf import OmegaConf
import mmcv
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)
import os
import cv2
import numpy as np
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPVisionModel
from PIL import Image


def preprocessing(image, device):
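    """Resize the image so its longer side is 640 px, subtract the per-channel
    (BGR) mean values, and return a (1, 3, H, W) float tensor on `device`
    together with the resized uint8 image."""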
    # Resize
    scale = 640 / max(image.shape[:2])
    image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
    raw_image = image.astype(np.uint8)

    # Subtract mean values
    image = image.astype(np.float32)
    image -= np.array(
        [
            float(104.008),
            float(116.669),
            float(122.675),
        ]
    )

    # Convert to torch.Tensor and add "batch" axis
    image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
    image = image.to(device)

    return image, raw_image


def imshow_keypoints(img,
                     pose_result,
                     skeleton=None,
                     kpt_score_thr=0.1,
                     pose_kpt_color=None,
                     pose_link_color=None,
                     radius=4,
                     thickness=1):
    """Draw keypoints and links on an image.

    Args:
            img (ndarry): The image to draw poses on.
            pose_result (list[kpts]): The poses to draw. Each element kpts is
                a set of K keypoints as an Kx3 numpy.ndarray, where each
                keypoint is represented as x, y, score.
            kpt_score_thr (float, optional): Minimum score of keypoints
                to be shown. Default: 0.3.
            pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
                the keypoint will not be drawn.
            pose_link_color (np.array[Mx3]): Color of M links. If None, the
                links will not be drawn.
            thickness (int): Thickness of lines.
    """

    img_h, img_w, _ = img.shape
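    # draw onto a black canvas of the same size; the content of the input image is discarded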
    img = np.zeros(img.shape)

    for idx, kpts in enumerate(pose_result):
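        # draw at most the first two persons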
        if idx > 1:
            continue
        kpts = kpts['keypoints']
        kpts = np.array(kpts, copy=False)

        # draw each point on image
        if pose_kpt_color is not None:
            assert len(pose_kpt_color) == len(kpts)

            for kid, kpt in enumerate(kpts):
                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]

                if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
                    # skip the point that should not be drawn
                    continue

                color = tuple(int(c) for c in pose_kpt_color[kid])
                cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)

        # draw links
        if skeleton is not None and pose_link_color is not None:
            assert len(pose_link_color) == len(skeleton)

            for sk_id, sk in enumerate(skeleton):
                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))

                if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
                        or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
                        or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
                    # skip the link that should not be drawn
                    continue
                color = tuple(int(c) for c in pose_link_color[sk_id])
                cv2.line(img, pos1, pos2, color, thickness=thickness)

    return img


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    if "state_dict" in pl_sd:
        sd = pl_sd["state_dict"]
    else:
        sd = pl_sd
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if missing:
            print(f"missing keys: {missing}")
        if unexpected:
            print(f"unexpected keys: {unexpected}")

    model.cuda()
    model.eval()
    return model


class Model_all:
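    """Bundle of the Stable Diffusion v1.4 base model, the PLMS sampler and all
    T2I-Adapter condition encoders (sketch, canny, segmentation, depth, keypose,
    openpose, color, style) used by the Gradio demo."""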
    def __init__(self, device='cpu'):
        # common part
        self.device = device
        self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
        self.config.model.params.cond_stage_config.params.device = device
        self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
        self.current_base = 'sd-v1-4.ckpt'
        self.sampler = PLMSSampler(self.base_model)

        # sketch part
        self.model_canny = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                    use_conv=False).to(device)
        self.model_canny.load_state_dict(torch.load("models/t2iadapter_canny_sd14v1.pth", map_location=device))
        self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                    use_conv=False).to(device)
        self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
        self.model_edge = pidinet().to(device)
        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in
                                         torch.load('models/table5_pidinet.pth', map_location=device)[
                                             'state_dict'].items()})

        # segmentation part
        self.model_seger = seger().to(device)
        self.model_seger.eval()
        self.colorizer = Colorize(n=182)
        self.model_seg = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                 use_conv=False).to(device)
        self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))

        # depth part
        self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
        self.model_depth = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                   use_conv=False).to(device)
        self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))

        # keypose part
        self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                  use_conv=False).to(device)
        self.model_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth", map_location=device))

        # openpose part
        self.model_openpose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                      use_conv=False).to(device)
        self.model_openpose.load_state_dict(torch.load("models/t2iadapter_openpose_sd14v1.pth", map_location=device))

        # color part
        self.model_color = Adapter_light(cin=int(3 * 64), channels=[320, 640, 1280, 1280], nums_rb=4).to(device)
        self.model_color.load_state_dict(torch.load("models/t2iadapter_color_sd14v1.pth", map_location=device))

        # style part
        self.model_style = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(device)
        self.model_style.load_state_dict(torch.load("models/t2iadapter_style_sd14v1.pth", map_location=device))
        self.clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
        self.clip_vision_model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14').to(device)

        # keep the mmdet / mmpose models on CPU
        device = 'cpu'
        ## mmpose
        det_config = 'models/faster_rcnn_r50_fpn_coco.py'
        det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
        pose_config = 'models/hrnet_w48_coco_256x192.py'
        pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
        self.det_cat_id = 1
        self.bbox_thr = 0.2
        ## detector
        det_config_mmcv = mmcv.Config.fromfile(det_config)
        self.det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
        pose_config_mmcv = mmcv.Config.fromfile(pose_config)
        self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
        ## skeleton / colors for keypoint visualization
        self.skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8],
                         [7, 9], [8, 10],
                         [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
        self.pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
                               [0, 255, 0],
                               [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0],
                               [255, 128, 0],
                               [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
        self.pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
                                [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
                                [255, 128, 0],
                                [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
                                [51, 153, 255],
                                [51, 153, 255], [51, 153, 255], [51, 153, 255]]

    def load_vae(self):
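        """Replace the first-stage (VAE) weights with the anything-v4.0 VAE."""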
        vae_sd = torch.load(os.path.join('models', 'anything-v4.0.vae.pt'), map_location="cuda")
        sd = vae_sd["state_dict"]
        self.base_model.first_stage_model.load_state_dict(sd, strict=False)

    @torch.no_grad()
    def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
                       con_strength, base_model):
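        """Generate an image conditioned on a sketch (or on edges extracted from an input image)."""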
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        # map the condition-strength slider (0-1) onto a step index of the 50-step PLMS schedule
        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        im = cv2.resize(input_img, (512, 512))

        if type_in == 'Sketch':
            if color_back == 'White':
                im = 255 - im
            im_edge = im.copy()
            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
            im = im > 0.5
            im = im.float()
        elif type_in == 'Image':
            im = img2tensor(im).unsqueeze(0) / 255.
            im = self.model_edge(im.to(self.device))[-1]
            im = im > 0.5
            im = im.float()
            im_edge = tensor2img(im)

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        features_adapter = self.model_sketch(im.to(self.device))
        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_edge, x_samples_ddim]
    
    @torch.no_grad()
    def process_canny(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
                       con_strength, base_model):
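        """Generate an image conditioned on a Canny edge map (computed from the input image if needed)."""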
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        im = cv2.resize(input_img, (512, 512))

        if type_in == 'Canny':
            if color_back == 'White':
                im = 255 - im
            im_edge = im.copy()
            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
        elif type_in == 'Image':
            im = cv2.Canny(im,100,200)
            im = img2tensor(im[..., None], bgr2rgb=True, float32=True).unsqueeze(0) / 255.
            im_edge = tensor2img(im)

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        features_adapter = self.model_canny(im.to(self.device))
        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_edge, x_samples_ddim]
    
    @torch.no_grad()
    def process_color_sketch(self, input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
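        """Generate an image conditioned jointly on a sketch and a spatial color palette, weighted by w_sketch and w_color."""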
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        im = cv2.resize(input_img_sketch, (512, 512))

        if type_in == 'Sketch':
            if color_back == 'White':
                im = 255 - im
            im_edge = im.copy()
            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
            im = im > 0.5
            im = im.float()
        elif type_in == 'Image':
            im = img2tensor(im).unsqueeze(0) / 255.
            im = self.model_edge(im.to(self.device))[-1]
            im = im > 0.5
            im = im.float()
            im_edge = tensor2img(im)
        if type_in_color == 'Image':
            # downsample the reference image to an 8x8 grid and upsample back with
            # nearest-neighbour interpolation to obtain a block-wise color palette
            input_img_color = cv2.resize(input_img_color, (512 // 64, 512 // 64), interpolation=cv2.INTER_CUBIC)
            input_img_color = cv2.resize(input_img_color, (512, 512), interpolation=cv2.INTER_NEAREST)
        else:
            input_img_color = cv2.resize(input_img_color, (512, 512))
        im_color = input_img_color.copy()
        im_color_tensor = img2tensor(input_img_color, bgr2rgb=False).unsqueeze(0) / 255.

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        features_adapter_sketch = self.model_sketch(im.to(self.device))
        features_adapter_color = self.model_color(im_color_tensor.to(self.device))
        features_adapter = [fs*w_sketch+fc*w_color for fs, fc in zip(features_adapter_sketch,features_adapter_color)]
        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_edge, im_color, x_samples_ddim]
    
    @torch.no_grad()
    def process_style_sketch(self, input_img_sketch, input_img_style, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
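        """Generate an image conditioned on a sketch plus the CLIP style features of a reference image."""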
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        im = cv2.resize(input_img_sketch, (512, 512))

        if type_in == 'Sketch':
            if color_back == 'White':
                im = 255 - im
            im_edge = im.copy()
            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
            im = im > 0.5
            im = im.float()
        elif type_in == 'Image':
            im = img2tensor(im).unsqueeze(0) / 255.
            im = self.model_edge(im.to(self.device))[-1]
            im = im > 0.5
            im = im.float()
            im_edge = tensor2img(im)
        
        style = Image.fromarray(input_img_style)
        style_for_clip = self.clip_processor(images=style, return_tensors="pt")['pixel_values']
        style_feat = self.clip_vision_model(style_for_clip.to(self.device))['last_hidden_state']
        style_feat = self.model_style(style_feat)

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        features_adapter = self.model_sketch(im.to(self.device))
        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='style',
                                              con_strength=con_strength,
                                              style_feature=style_feat)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_edge, x_samples_ddim]

    @torch.no_grad()
    def process_color(self, input_img, prompt, neg_prompt, pos_prompt, w_color, type_in_color, fix_sample, scale, con_strength, base_model):
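        """Generate an image conditioned on a spatial color palette."""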
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        if type_in_color == 'Image':
            # downsample to an 8x8 grid and upsample back with nearest-neighbour
            # interpolation to obtain a block-wise color palette
            input_img = cv2.resize(input_img, (512 // 64, 512 // 64), interpolation=cv2.INTER_CUBIC)
            input_img = cv2.resize(input_img, (512, 512), interpolation=cv2.INTER_NEAREST)
        else:
            input_img = cv2.resize(input_img, (512, 512))

        im_color = input_img.copy()
        im = img2tensor(input_img, bgr2rgb=False).unsqueeze(0) / 255.

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        features_adapter = self.model_color(im.to(self.device))
        features_adapter = [fi*w_color for fi in features_adapter]
        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_color, x_samples_ddim]
    
    @torch.no_grad()
    def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
                      con_strength, base_model):
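        """Generate an image conditioned on a depth map (estimated with MiDaS if an image is given)."""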
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        im = cv2.resize(input_img, (512, 512))

        if type_in == 'Depth':
            im_depth = im.copy()
            depth = img2tensor(im).unsqueeze(0) / 255.
        elif type_in == 'Image':
            im = img2tensor(im).unsqueeze(0) / 127.5 - 1.0
            depth = self.depth_model(im.to(self.device)).repeat(1, 3, 1, 1)
            depth -= torch.min(depth)
            depth /= torch.max(depth)
            im_depth = tensor2img(depth)

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        features_adapter = self.model_depth(depth.to(self.device))
        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_depth, x_samples_ddim]

    @torch.no_grad()
    def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth,
                              w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
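        """Generate an image conditioned jointly on a depth map and human keyposes, weighted by w_depth and w_keypose."""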
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        if fix_sample == 'True':
            seed_everything(42)
        im_depth = cv2.resize(input_img_depth, (512, 512))
        im_keypose = cv2.resize(input_img_keypose, (512, 512))

        # get depth 
        if type_in_depth == 'Depth':
            im_depth_out = im_depth.copy()
            depth = img2tensor(im_depth).unsqueeze(0) / 255.
        elif type_in_depth == 'Image':
            im_depth = img2tensor(im_depth).unsqueeze(0) / 127.5 - 1.0
            depth = self.depth_model(im_depth.to(self.device)).repeat(1, 3, 1, 1)
            depth -= torch.min(depth)
            depth /= torch.max(depth)
            im_depth_out = tensor2img(depth)

        # get keypose
        if type_in_keypose == 'Keypose':
            im_keypose_out = im_keypose.copy()[:,:,::-1]
        elif type_in_keypose == 'Image':
            image = im_keypose.copy()
            im_keypose = img2tensor(im_keypose).unsqueeze(0) / 255.
            mmdet_results = inference_detector(self.det_model, image)
            # keep the person class bounding boxes.
            person_results = process_mmdet_results(mmdet_results, self.det_cat_id)

            # optional
            return_heatmap = False
            dataset = self.pose_model.cfg.data['test']['type']

            # e.g. use ('backbone', ) to return backbone feature
            output_layer_names = None
            pose_results, _ = inference_top_down_pose_model(
                self.pose_model,
                image,
                person_results,
                bbox_thr=self.bbox_thr,
                format='xyxy',
                dataset=dataset,
                dataset_info=None,
                return_heatmap=return_heatmap,
                outputs=output_layer_names)

            # show the results
            im_keypose_out = imshow_keypoints(
                image,
                pose_results,
                skeleton=self.skeleton,
                pose_kpt_color=self.pose_kpt_color,
                pose_link_color=self.pose_link_color,
                radius=2,
                thickness=2)
            im_keypose_out = im_keypose_out.astype(np.uint8)

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        features_adapter_depth = self.model_depth(depth.to(self.device))
        pose = img2tensor(im_keypose_out, bgr2rgb=True, float32=True) / 255.
        pose = pose.unsqueeze(0)
        features_adapter_keypose = self.model_pose(pose.to(self.device))
        features_adapter = [f_d * w_depth + f_k * w_keypose for f_d, f_k in
                            zip(features_adapter_depth, features_adapter_keypose)]
        shape = [4, 64, 64]

        # sampling
        con_strength = int((1 - con_strength) * 50)
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_depth_out, im_keypose_out[:, :, ::-1], x_samples_ddim]

    @torch.no_grad()
    def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
                    con_strength, base_model):
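        """Generate an image conditioned on a semantic segmentation map (predicted from the input image if needed)."""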
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        im = cv2.resize(input_img, (512, 512))

        if type_in == 'Segmentation':
            im_seg = im.copy()
            im = img2tensor(im).unsqueeze(0) / 255.
            labelmap = im.float()
        elif type_in == 'Image':
            im, _ = preprocessing(im, self.device)
            _, _, H, W = im.shape

            # Image -> Probability map
            logits = self.model_seger(im)
            logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
            probs = F.softmax(logits, dim=1)[0]
            probs = probs.cpu().data.numpy()
            labelmap = np.argmax(probs, axis=0)

            labelmap = self.colorizer(labelmap)
            labelmap = np.transpose(labelmap, (1, 2, 0))
            labelmap = cv2.resize(labelmap, (512, 512))
            labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True) / 255.
            im_seg = tensor2img(labelmap)[:, :, ::-1]
            labelmap = labelmap.unsqueeze(0)

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        features_adapter = self.model_seg(labelmap.to(self.device))
        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_seg, x_samples_ddim]

    @torch.no_grad()
    def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
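        """Generate an image conditioned on a hand-drawn sketch from the Gradio sketchpad."""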
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
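        # take the user's strokes from the Gradio sketchpad mask and composite them onto a white background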
        input_img = input_img['mask']
        c = input_img[:, :, 0:3].astype(np.float32)
        a = input_img[:, :, 3:4].astype(np.float32) / 255.0
        im = c * a + 255.0 * (1.0 - a)
        im = im.clip(0, 255).astype(np.uint8)
        im = cv2.resize(im, (512, 512))

        im_edge = im.copy()
        im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
        im = im > 0.5
        im = im.float()

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        features_adapter = self.model_sketch(im.to(self.device))
        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_edge, x_samples_ddim]

    @torch.no_grad()
    def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength,
                        base_model):
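        """Generate an image conditioned on human keyposes (detected with mmdet/mmpose if an image is given)."""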
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        im = cv2.resize(input_img, (512, 512))

        if type_in == 'Keypose':
            im_pose = im.copy()[:,:,::-1]
        elif type_in == 'Image':
            image = im.copy()
            im = img2tensor(im).unsqueeze(0) / 255.
            mmdet_results = inference_detector(self.det_model, image)
            # keep the person class bounding boxes.
            person_results = process_mmdet_results(mmdet_results, self.det_cat_id)

            # optional
            return_heatmap = False
            dataset = self.pose_model.cfg.data['test']['type']

            # e.g. use ('backbone', ) to return backbone feature
            output_layer_names = None
            pose_results, _ = inference_top_down_pose_model(
                self.pose_model,
                image,
                person_results,
                bbox_thr=self.bbox_thr,
                format='xyxy',
                dataset=dataset,
                dataset_info=None,
                return_heatmap=return_heatmap,
                outputs=output_layer_names)

            # show the results
            im_pose = imshow_keypoints(
                image,
                pose_results,
                skeleton=self.skeleton,
                pose_kpt_color=self.pose_kpt_color,
                pose_link_color=self.pose_link_color,
                radius=2,
                thickness=2)

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
        pose = pose.unsqueeze(0)
        features_adapter = self.model_pose(pose.to(self.device))

        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
    
    @torch.no_grad()
    def process_openpose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength,
                        base_model):
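        """Generate an image conditioned on an OpenPose pose map (estimated from the input image if needed)."""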
        if self.current_base != base_model:
            ckpt = os.path.join("models", base_model)
            pl_sd = torch.load(ckpt, map_location="cuda")
            if "state_dict" in pl_sd:
                sd = pl_sd["state_dict"]
            else:
                sd = pl_sd
            self.base_model.load_state_dict(sd, strict=False)
            self.current_base = base_model
            if 'anything' in base_model.lower():
                self.load_vae()

        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        im = cv2.resize(input_img, (512, 512))

        if type_in == 'Openpose':
            im_pose = im.copy()[:,:,::-1]
        elif type_in == 'Image':
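            # build the OpenPose estimator on demand; it is only needed when a raw image is given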
            from ldm.modules.structure_condition.openpose.api import OpenposeInference
            model = OpenposeInference()
            keypose = model(im[:,:,::-1])
            im_pose = keypose.copy()

        # extract condition features
        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
        nc = self.base_model.get_learned_conditioning([neg_prompt])
        pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
        pose = pose.unsqueeze(0)
        features_adapter = self.model_openpose(pose.to(self.device))

        shape = [4, 64, 64]

        # sampling
        samples_ddim, _ = self.sampler.sample(S=50,
                                              conditioning=c,
                                              batch_size=1,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=scale,
                                              unconditional_conditioning=nc,
                                              eta=0.0,
                                              x_T=None,
                                              features_adapter1=features_adapter,
                                              mode='sketch',
                                              con_strength=con_strength)

        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

        return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]


if __name__ == '__main__':
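    # instantiate every model on CPU when the module is run directly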
    model = Model_all('cpu')