import cv2
import numpy as np
import mp_triangles
import time
from PIL import Image
from skimage.exposure import match_histograms

from glibvision.cv2_utils import (
    blend_rgb_images,
    pil_to_bgr_image,
    fill_points,
    crop,
    paste,
)
from mp_utils import (
    get_pixel_cordinate_list,
    extract_landmark,
    get_pixel_cordinate,
    get_normalized_landmarks,
    sort_triangles_by_depth,
    get_landmark_bbox,
)

import numba as nb


@nb.jit(nopython=True, parallel=True)
def blend_rgb_images_numba(image1, image2, mask):
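    """Per-pixel alpha blend of two 3-channel images; mask (0-255) weights image2."""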
    height, width, _ = image1.shape
    result = np.empty((height, width, 3), dtype=np.float32)

    for i in nb.prange(height):
        for j in range(width):
            alpha = mask[i, j] / 255.0
            for k in range(3):
                result[i, j, k] = (1 - alpha) * image1[i, j, k] + alpha * image2[
                    i, j, k
                ]

    return result.astype(np.uint8)


@nb.jit(nopython=True, parallel=True)
def blend_rgba_images_numba(image1, image2, mask):
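    """Per-pixel alpha blend of two images with the same channel count (e.g. BGRA)."""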
    assert (
        image1.shape[2] == image2.shape[2]
    ), f"Channel count mismatch: image1 has {image1.shape[2]}, image2 has {image2.shape[2]}"
    channel = image1.shape[2]
    height, width, _ = image1.shape
    result = np.empty((height, width, channel), dtype=np.float32)

    for i in nb.prange(height):
        for j in range(width):
            alpha = mask[i, j] / 255.0
            for k in range(channel):
                result[i, j, k] = (1 - alpha) * image1[i, j, k] + alpha * image2[
                    i, j, k
                ]

    return result.astype(np.uint8)


"""
https://stackoverflow.com/questions/6946653/copying-triangular-image-region-with-pil
This topic give me a idea
"""
"""
bug some hide value make white
"""
debug_affinn = False
min_affin_plus = 0.1


def apply_affine_transformation_to_triangle(src_tri, dst_tri, src_img, dst_img):
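    """Warp the src_tri region of src_img onto the dst_tri region of dst_img."""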
    src_tri_np = np.float32(src_tri)
    dst_tri_np = np.float32(dst_tri)

    assert src_tri_np.shape == (3, 2), f"src_tri_np has invalid shape {src_tri_np.shape}"
    assert dst_tri_np.shape == (3, 2), f"dst_tri_np has invalid shape {dst_tri_np.shape}"

    # Nudge coincident vertices apart; identical points make getAffineTransform degenerate
    if (src_tri_np[0] == src_tri_np[1]).all():
        src_tri_np[0] += min_affin_plus
    if (src_tri_np[0] == src_tri_np[2]).all():
        src_tri_np[0] += min_affin_plus
    if (src_tri_np[1] == src_tri_np[2]).all():
        src_tri_np[1] += min_affin_plus
    if (src_tri_np[1] == src_tri_np[0]).all():
        src_tri_np[1] += min_affin_plus

    if (
        (src_tri_np[1] == src_tri_np[0]).all()
        or (src_tri_np[1] == src_tri_np[2]).all()
        or (src_tri_np[2] == src_tri_np[0]).all()
    ):
        print("same will white noise happen")
    # 透視変換行列の計算
    M = cv2.getAffineTransform(src_tri_np, dst_tri_np)
    # Image sizes
    h_src, w_src = src_img.shape[:2]
    h_dst, w_dst = dst_img.shape[:2]

    # Mask that would cut the triangle region out of the source image
    # src_mask = np.zeros((h_src, w_src), dtype=np.uint8)
    # cv2.fillPoly(src_mask, [np.int32(src_tri)], 255)

    # Note: the source is warped whole; the triangle is masked on the destination side instead
    src_triangle = src_img  # cv2.bitwise_and(src_img, src_img, mask=src_mask)

    # Warp the source triangle region to the destination image size
    transformed = cv2.warpAffine(src_triangle, M, (w_dst, h_dst))
    if debug_affinn:
        cv2.imwrite("affin_src.jpg", src_triangle)
        cv2.imwrite("affin_transformed.jpg", transformed)

    # print(f"dst_img={dst_img.shape}")
    # print(f"transformed={transformed.shape}")
    # 変換後のマスクの生成
    dst_mask = np.zeros((h_dst, w_dst), dtype=np.uint8)
    cv2.fillPoly(dst_mask, [np.int32(dst_tri)], 255)

    # 目標画像のマスク領域をクリアするためにデストのインバートマスクを作成
    # dst_mask_inv = cv2.bitwise_not(dst_mask)

    # 目標画像のマスク部分をクリア
    # dst_background = cv2.bitwise_and(dst_img, dst_img, mask=dst_mask_inv)

    # 変換された元画像の三角形部分と目標画像の背景部分を合成
    # dst_img = cv2.add(dst_background, transformed)
    # s = time.time()
    # dst_img = blend_rgb_images(dst_img,transformed,dst_mask)

    use_blend_rgb = False
    if use_blend_rgb:
        # Soft per-pixel blend (slower, but respects anti-aliased masks)
        if src_img.shape[2] == 3:
            dst_img = blend_rgb_images_numba(dst_img, transformed, dst_mask)
        else:
            dst_img = blend_rgba_images_numba(dst_img, transformed, dst_mask)
    else:
        # Hard mask: clear the triangle in the destination, then add the warped source
        dst_mask_inv = cv2.bitwise_not(dst_mask)
        transformed = cv2.bitwise_and(transformed, transformed, mask=dst_mask)
        dst_img = cv2.bitwise_and(dst_img, dst_img, mask=dst_mask_inv)
        dst_img = cv2.add(dst_img, transformed)

    # TODO: add RGB mode

    # print(f"blend {time.time() -s}")
    if debug_affinn:
        cv2.imwrite("affin_transformed_masked.jpg", transformed)
        cv2.imwrite("affin_dst_mask.jpg", dst_mask)
    return dst_img


def color_match(base_image, cropped_image, color_match_format="RGB"):
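    """Histogram-match cropped_image's colors to base_image (both PIL images)."""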
    reference = np.array(base_image.convert(color_match_format))
    target = np.array(cropped_image.convert(color_match_format))
    matched = match_histograms(target, reference, channel_axis=-1)
    # match_histograms returns float64; convert back to uint8 for PIL
    return Image.fromarray(matched.astype(np.uint8), mode=color_match_format)


def process_landmark_transform(
    image,
    transform_target_image,
    innner_mouth,
    innner_eyes,
    color_matching=False,
    transparent_background=False,
    add_align_mouth=False,
    add_align_eyes=False,
    blur_size=0,
):
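    """Warp the face in `image` onto the face in `transform_target_image`,
    triangle by triangle, using MediaPipe face landmarks.

    innner_mouth / innner_eyes: when False, the target's own inner mouth /
    eye regions are restored after blending.
    color_matching: histogram-match the source face to the target first.
    transparent_background: return a BGRA image with the face on transparency.
    add_align_mouth / add_align_eyes: with a transparent background, paste the
    target's inner mouth / eye regions back in.
    blur_size: feather the seams with a Gaussian blur of this kernel size.
    """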
    image_h, image_w = image.shape[:2]
    align_h, align_w = transform_target_image.shape[:2]

    mp_image, image_face_landmarker_result = extract_landmark(image)
    image_larndmarks = image_face_landmarker_result.face_landmarks
    image_bbox = get_landmark_bbox(image_larndmarks, image_w, image_h, 16, 16)

    mp_image, align_face_landmarker_result = extract_landmark(transform_target_image)
    align_larndmarks = align_face_landmarker_result.face_landmarks
    align_bbox = get_landmark_bbox(align_larndmarks, align_w, align_h, 16, 16)

    if color_matching:
        image_cropped = crop(image, image_bbox)
        target_cropped = crop(transform_target_image, align_bbox)
        # match_histograms returns float64; cast back before pasting into the uint8 image
        matched = match_histograms(
            image_cropped, target_cropped, channel_axis=-1
        ).astype(np.uint8)
        paste(image, matched, image_bbox[0], image_bbox[1])

    landmark_points = get_normalized_landmarks(align_larndmarks)

    mesh_triangle_indices = (
        mp_triangles.mesh_triangle_indices.copy()
    )  # copy: the module-level list is shared; using it directly could mutate it

    # Always include the inner mouth and eyes so the blur covers them
    mesh_triangle_indices += mp_triangles.INNER_MOUTH

    mesh_triangle_indices += (
        mp_triangles.INNER_LEFT_EYES + mp_triangles.INNER_RIGHT_EYES
    )
    # print(mesh_triangle_indices)
    sort_triangles_by_depth(landmark_points, mesh_triangle_indices)

    # mesh_triangle_indices = mp_triangles.contour_to_triangles(True,draw_updown_contour) + mp_triangles.contour_to_triangles(False,draw_updown_contour)+ mp_triangles.mesh_triangle_indices

    triangle_size = len(mesh_triangle_indices)
    # print(f"triangle_size = {triangle_size},time ={0.1*triangle_size}")
    s = time.time()

    need_transparent_way = transparent_background or blur_size > 0
    if need_transparent_way:  # work on a fully transparent BGRA canvas
        transparent_image = np.zeros_like(
            cv2.cvtColor(transform_target_image, cv2.COLOR_BGR2BGRA)
        )
        applied_image = transparent_image
        image = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
    else:
        applied_image = transform_target_image

    for i in range(triangle_size):
        triangle_indices = mesh_triangle_indices[i]

        image_points = get_pixel_cordinate_list(
            image_larndmarks, triangle_indices, image_w, image_h
        )

        align_points = get_pixel_cordinate_list(
            align_larndmarks, triangle_indices, align_w, align_h
        )
        # print(image_points)
        # print(align_points)
        # fill_points(image,image_points,thickness=3,fill_color=(0,0,0,0))
        # s = time.time()
        # print(f"applied_image={applied_image.shape}")
        applied_image = apply_affine_transformation_to_triangle(
            image_points, align_points, image, applied_image
        )

    # print(f"take time {time.time()-s}")
    if need_transparent_way:
        blur_radius = blur_size
        if blur_radius != 0 and blur_radius % 2 == 0:
            blur_radius += 1  # GaussianBlur kernel size must be odd

        b, g, r, a = cv2.split(applied_image)
        applied_image = cv2.merge([b, g, r])
        mask = a.copy()

        if blur_radius > 0:
            # Erode first so the blurred edge stays inside the warped region
            kernel = np.ones((blur_radius, blur_radius), np.uint8)
            mask = cv2.erode(mask, kernel, iterations=1)
            blurred_image = cv2.GaussianBlur(mask, (blur_radius, blur_radius), 0)
        else:
            blurred_image = mask

        if transparent_background:
            # Keep BGRA output, using the (feathered) mask as the alpha channel
            # transform_target_image = np.zeros_like(cv2.cvtColor(transform_target_image, cv2.COLOR_BGR2BGRA))
            transform_target_image = cv2.cvtColor(
                transform_target_image, cv2.COLOR_BGR2BGRA
            )
            applied_image = cv2.merge([b, g, r, blurred_image])
        else:
            # Blend the warped face back onto the target with the soft mask
            applied_image = blend_rgb_images(
                transform_target_image, applied_image, blurred_image
            )

    # After blending: restore or clear the inner mouth / eye regions
    if (
        not innner_mouth
        or not innner_eyes
        or (transparent_background and (add_align_mouth or add_align_eyes))
    ):
        import mp_constants

        dst_mask = np.zeros((align_h, align_w), dtype=np.uint8)
        if not innner_mouth or (transparent_background and add_align_mouth):
            mouth_cordinates = get_pixel_cordinate_list(
                align_larndmarks, mp_constants.LINE_INNER_MOUTH, align_w, align_h
            )
            cv2.fillPoly(dst_mask, [np.int32(mouth_cordinates)], 255)

            if transparent_background and not add_align_mouth:
                cv2.fillPoly(
                    transform_target_image, [np.int32(mouth_cordinates)], [0, 0, 0, 0]
                )

        if not innner_eyes or (transparent_background and add_align_eyes):
            left_eyes_cordinates = get_pixel_cordinate_list(
                align_larndmarks, mp_constants.LINE_LEFT_INNER_EYES, align_w, align_h
            )

            cv2.fillPoly(dst_mask, [np.int32(left_eyes_cordinates)], 255)

            right_eyes_cordinates = get_pixel_cordinate_list(
                align_larndmarks, mp_constants.LINE_RIGHT_INNER_EYES, align_w, align_h
            )
            cv2.fillPoly(dst_mask, [np.int32(right_eyes_cordinates)], 255)

            if transparent_background and not add_align_eyes:
                cv2.fillPoly(
                    transform_target_image,
                    [np.int32(left_eyes_cordinates)],
                    [0, 0, 0, 0],
                )
                cv2.fillPoly(
                    transform_target_image,
                    [np.int32(right_eyes_cordinates)],
                    [0, 0, 0, 0],
                )

        # cv2.imwrite("deb_transform_target_image.jpg",transform_target_image)
        # cv2.imwrite("deb_dst_mask.jpg",dst_mask)
        # cv2.imwrite("deb_applied_image.jpg",applied_image)
        applied_image = blend_rgba_images_numba(
            applied_image, transform_target_image, dst_mask
        )
        cv2.imwrite("deb_final_transform_target_image.jpg", transform_target_image)

    return applied_image


def process_landmark_transform_pil(
    pil_image,
    pil_align_target_image,
    innner_mouth,
    innner_eyes,
    color_matching=False,
    transparent_background=False,
    add_align_mouth=False,
    add_align_eyes=False,
    blur_size=0,
):
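    """PIL convenience wrapper around process_landmark_transform; returns an
    RGBA image when transparent_background is set, otherwise RGB."""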
    image = pil_to_bgr_image(pil_image)
    align_target_image = pil_to_bgr_image(pil_align_target_image)
    cv_result = process_landmark_transform(
        image,
        align_target_image,
        innner_mouth,
        innner_eyes,
        color_matching,
        transparent_background,
        add_align_mouth,
        add_align_eyes,
        blur_size,
    )
    if transparent_background:
        return Image.fromarray(cv2.cvtColor(cv_result, cv2.COLOR_BGRA2RGBA))
    else:
        return Image.fromarray(cv2.cvtColor(cv_result, cv2.COLOR_BGR2RGB))

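# Minimal usage sketch for the PIL entry point (example file names are
# placeholders, not files shipped with this repo):
#
#   src = Image.open("examples/source_face.jpg")
#   target = Image.open("examples/target_face.jpg")
#   result = process_landmark_transform_pil(
#       src, target, innner_mouth=True, innner_eyes=True, blur_size=5
#   )
#   result.save("warped.png")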

if __name__ == "__main__":
    # image = Image.open('examples/00002062.jpg')
    # align_target = Image.open('examples/02316230.jpg')
    image = cv2.imread("examples/02316230.jpg")  # source image
    align_target = cv2.imread("examples/00003245_00.jpg")  # target image
    # innner_mouth / innner_eyes are required arguments; True keeps the warped regions
    result_img = process_landmark_transform(
        image, align_target, innner_mouth=True, innner_eyes=True
    )

    cv2.imshow("Transformed Image", result_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    cv2.imwrite("align.png", result_img)