'''
Cropping utilities for batched torch tensors.

Given an image and a bounding box (center, bbox size), return the cropped
image and the transform (tform) used to map keypoints into the crop.

Only square crops are supported.
'''
import torch
from kornia.geometry.transform.imgwarp import (warp_perspective,
                                               get_perspective_transform,
                                               warp_affine)


def points2bbox(points, points_scale=None):
    '''Compute a square bounding box enclosing each set of 2D points.

    Args:
        points (torch.Tensor): [bz, n_points, >=2]; columns 0 and 1 are x, y.
        points_scale: optional (s, s). When given, x/y are assumed to be in
            [-1, 1] and are mapped to pixel units via (p * 0.5 + 0.5) * s.
            The input tensor is not modified.
    Returns:
        center: [bz, 2] box centers (x, y).
        size: [bz, 1] side length of the square box (max of width, height).
    '''
    if points_scale:
        # Only isotropic scaling is supported.
        assert points_scale[0] == points_scale[1]
        points = points.clone()
        points[:, :, :2] = (points[:, :, :2] * 0.5 + 0.5) * points_scale[0]

    lower = points.min(dim=1).values
    upper = points.max(dim=1).values
    x_lo, y_lo = lower[:, 0], lower[:, 1]
    x_hi, y_hi = upper[:, 0], upper[:, 1]

    center = torch.stack([x_hi + x_lo, y_hi + y_lo], dim=-1) * 0.5
    # Make the box square by taking the larger of width and height.
    side = torch.max(x_hi - x_lo, y_hi - y_lo)
    return center, side.unsqueeze(-1)


def augment_bbox(center, bbox_size, scale=(1.0, 1.0), trans_scale=0.):
    '''Randomly shift and rescale a batch of bounding boxes.

    Args:
        center (torch.Tensor): bbox centers, [bz, 2].
        bbox_size (torch.Tensor): bbox sizes, [bz, 1].
        scale: (low, high); each bbox size is multiplied by a factor drawn
            uniformly from [low, high). Any indexable pair is accepted.
        trans_scale (float): each center coordinate is shifted by a value
            drawn uniformly from [-trans_scale, trans_scale) times bbox_size.
    Returns:
        center: shifted centers, [bz, 2].
        size: rescaled sizes, [bz, 1].
    '''
    batch_size = center.shape[0]
    # Independent uniform shift in [-1, 1) for x and y, scaled by trans_scale.
    shift = (torch.rand([batch_size, 2], device=center.device) * 2. -
             1.) * trans_scale
    center = center + shift * bbox_size
    low, high = scale[0], scale[1]
    factor = torch.rand([batch_size, 1],
                        device=center.device) * (high - low) + low
    size = bbox_size * factor
    return center, size


def crop_tensor(image,
                center,
                bbox_size,
                crop_size,
                interpolation='bilinear',
                align_corners=False):
    '''Crop a square region around each center and resample it to crop_size.

    Args:
        image (torch.Tensor): batch of images, BxCxHxW (the layout kornia's
            ``warp_affine`` expects; channels-last input is not supported).
        center (torch.Tensor): bbox centers in input pixel coordinates,
            [bz, 2] as (x, y).
        bbox_size (torch.Tensor): side length of the square bbox, [bz, 1].
        crop_size (int): side length of the square output crop.
        interpolation (str): Interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid_generation. Default: False. See
          https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details
    Returns:
        cropped_image: [bz, C, crop_size, crop_size].
        tform: [bz, 3, 3] transpose of the input->crop perspective transform,
            meant for row-vector homogeneous points: ``[x, y, 1] @ tform``.
    '''
    dtype = image.dtype
    device = image.device
    batch_size = image.shape[0]
    # Half side length; boxes are square, so only column 0 of bbox_size is used
    # for the mixed-sign corners.
    half = bbox_size[:, 0] * 0.5

    # Source corners in input pixels (x, y):
    # top-left, top-right, bottom-right, bottom-left.
    src_pts = torch.zeros([batch_size, 4, 2], dtype=dtype, device=device)
    src_pts[:, 0, :] = center - bbox_size * 0.5
    src_pts[:, 1, 0] = center[:, 0] + half
    src_pts[:, 1, 1] = center[:, 1] - half
    src_pts[:, 2, :] = center + bbox_size * 0.5
    src_pts[:, 3, 0] = center[:, 0] - half
    src_pts[:, 3, 1] = center[:, 1] + half

    # Matching corners of the output crop, in the same order.
    dst_pts = torch.tensor([[
        [0, 0],
        [crop_size - 1, 0],
        [crop_size - 1, crop_size - 1],
        [0, crop_size - 1],
    ]],
                           dtype=dtype,
                           device=device).expand(batch_size, -1, -1)

    # Estimate the input->crop transform from the 4 point correspondences.
    dst_trans_src = get_perspective_transform(src_pts, dst_pts)

    # Axis-aligned square->square mapping is affine, so the top 2 rows suffice.
    cropped_image = warp_affine(image,
                                dst_trans_src[:, :2, :],
                                (crop_size, crop_size),
                                mode=interpolation,
                                align_corners=align_corners)

    # Transposed so callers can right-multiply row-vector points.
    tform = dst_trans_src.transpose(1, 2)
    return cropped_image, tform


class Cropper(object):
    '''Crop image batches around 2D points, with optional random bbox jitter.'''

    def __init__(self, crop_size, scale=(1, 1), trans_scale=0.):
        '''
        Args:
            crop_size (int): side length of the square output crop.
            scale: (low, high) range for the random bbox scale factor.
            trans_scale (float): max random center shift, as a fraction of
                the bbox size.
        '''
        self.crop_size = crop_size
        # Copy into a fresh list so instances never share one mutable object.
        self.scale = list(scale)
        self.trans_scale = trans_scale

    def crop(self, image, points, points_scale=None):
        '''Crop `image` around the square bbox enclosing `points`.

        Args:
            image (torch.Tensor): image batch passed to crop_tensor.
            points (torch.Tensor): [bz, n_points, >=2] points used to place
                the bbox.
            points_scale: optional (s, s); see points2bbox.
        Returns:
            cropped_image, tform as returned by crop_tensor.
        '''
        # points to bbox
        center, bbox_size = points2bbox(points.clone(), points_scale)
        # augment bbox. TODO: add rotation?
        center, bbox_size = augment_bbox(center,
                                         bbox_size,
                                         scale=self.scale,
                                         trans_scale=self.trans_scale)
        # crop
        cropped_image, tform = crop_tensor(image, center, bbox_size,
                                           self.crop_size)
        return cropped_image, tform

    def transform_points(self,
                         points,
                         tform,
                         points_scale=None,
                         normalize=True):
        '''Map points into crop coordinates using `tform`.

        Args:
            points (torch.Tensor): [bz, n_points, >=2]; only x, y (columns 0,
                1) are transformed, remaining columns are passed through.
            tform (torch.Tensor): [bz, 3, 3] row-vector transform, applied as
                ``[x, y, 1] @ tform``.
            points_scale: optional (s, s). When given, input x/y are assumed
                to be in [-1, 1] and are first mapped to pixel units.
            normalize (bool): if True, map output x/y from crop pixels to
                ``x / crop_size * 2 - 1``.
        Returns:
            torch.Tensor of the same shape as `points`.
        '''
        points_2d = points[:, :, :2]

        # Input points must use the original (pixel) range.
        if points_scale:
            assert points_scale[0] == points_scale[1]
            points_2d = (points_2d * 0.5 + 0.5) * points_scale[0]

        batch_size, n_points, _ = points.shape
        ones = torch.ones([batch_size, n_points, 1],
                          device=points.device,
                          dtype=points.dtype)
        trans_points_2d = torch.bmm(torch.cat([points_2d, ones], dim=-1),
                                    tform)
        trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]],
                                 dim=-1)
        if normalize:
            trans_points[:, :, :2] = (trans_points[:, :, :2] /
                                      self.crop_size * 2 - 1)
        return trans_points


def transform_points(points, tform, points_scale=None):
    '''Apply a row-vector 3x3 transform to the x, y columns of `points`.

    Args:
        points (torch.Tensor): [bz, n_points, >=2]; columns beyond x, y are
            copied through unchanged.
        tform (torch.Tensor): [bz, 3, 3], applied as ``[x, y, 1] @ tform``.
        points_scale: optional (s, s). When given, input x/y are assumed to be
            in [-1, 1] and are mapped to pixel units via (p * 0.5 + 0.5) * s.
    Returns:
        torch.Tensor with the same shape as `points`.
    '''
    xy = points[:, :, :2]
    if points_scale:
        # Input points must use the original range; only isotropic scales.
        assert points_scale[0] == points_scale[1]
        xy = (xy * 0.5 + 0.5) * points_scale[0]

    bsz, npts, _ = points.shape
    # Homogeneous coordinate column, matching points' dtype and device.
    ones = points.new_ones([bsz, npts, 1])
    xy_h = torch.cat([xy, ones], dim=-1)
    mapped = torch.bmm(xy_h, tform)
    return torch.cat([mapped[:, :, :2], points[:, :, 2:]], dim=-1)