# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.cuda.amp import autocast
import hydra

from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

from lightglue import LightGlue, SuperPoint, SIFT, ALIKED

import pycolmap

# from visdom import Visdom


from vggsfm.datasets.demo_loader import DemoLoader
from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras

# poselib is an optional dependency: the poselib-based estimator is only
# importable when poselib itself is. Catch ImportError specifically so that
# unrelated failures (and KeyboardInterrupt / SystemExit) are not swallowed.
try:
    import poselib
    from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras_poselib

    print("Poselib is available")
except ImportError:
    print("Poselib is not installed. Please disable use_poselib")

from vggsfm.utils.utils import (
    set_seed_and_print,
    farthest_point_sampling,
    calculate_index_mappings,
    switch_tensor_order,
)


def demo_fn(cfg):
    """
    Run the demo pipeline on every sequence exposed by the demo dataset.

    The function instantiates the model described by ``cfg.MODEL``, loads the
    released checkpoint, and calls ``run_one_scene`` once per sequence inside
    ``torch.no_grad()``.

    Args:
        cfg: OmegaConf config. Fields read here: ``seed``, ``MODEL``,
            ``SCENE_DIR``, ``img_size``, ``load_gt``, ``mixed_precision``
            (one of ``"None"``, ``"bf16"``, ``"fp16"``) and
            ``query_frame_num``. The whole config is also forwarded to the
            model, the dataset and ``run_one_scene``.

    Returns:
        The predictions dict of the LAST processed sequence, or ``None`` when
        the dataset exposes no sequence.

    Raises:
        NotImplementedError: if ``cfg.mixed_precision`` is not supported.
    """
    OmegaConf.set_struct(cfg, False)

    # Print configuration
    print("Model Config:", OmegaConf.to_yaml(cfg))

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    # Set seed
    seed_all_random_engines(cfg.seed)

    # Resolve the autocast dtype once, up front. An explicit exception is used
    # instead of `assert` because asserts are stripped under `python -O`.
    dtype_by_precision = {
        "None": torch.float32,
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    if cfg.mixed_precision not in dtype_by_precision:
        raise NotImplementedError(f"dtype {cfg.mixed_precision} is not supported now")
    dtype = dtype_by_precision[cfg.mixed_precision]

    # Model instantiation
    model = instantiate(cfg.MODEL, _recursive_=False, cfg=cfg)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)

    # Prepare test dataset
    test_dataset = DemoLoader(
        SCENE_DIR=cfg.SCENE_DIR, img_size=cfg.img_size, normalize_cameras=False, load_gt=cfg.load_gt, cfg=cfg
    )

    _VGGSFM_URL = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_0_0.bin"

    # Reload model.
    # map_location="cpu" keeps loading independent of the device the checkpoint
    # was saved from; load_state_dict then copies the values into the
    # parameters that already live on `device`.
    checkpoint = torch.hub.load_state_dict_from_url(_VGGSFM_URL, map_location="cpu")
    model.load_state_dict(checkpoint, strict=True)
    print(f"Successfully resumed from {_VGGSFM_URL}")

    # Stays None when the dataset has no sequence, so the final return is
    # always well-defined.
    predictions = None

    for seq_name in test_dataset.sequence_list:
        print("*" * 50 + f" Testing on Scene {seq_name} " + "*" * 50)

        # Load the data
        batch, image_paths = test_dataset.get_data(sequence_name=seq_name, return_path=True)

        # Send to the compute device and unsqueeze to have batch size = 1
        images = batch["image"].to(device).unsqueeze(0)
        crop_params = batch["crop_params"].to(device).unsqueeze(0)

        with torch.no_grad():
            # Run the model
            predictions = run_one_scene(
                model,
                images,
                crop_params=crop_params,
                query_frame_num=cfg.query_frame_num,
                image_paths=image_paths,
                dtype=dtype,
                cfg=cfg,
            )

    return predictions


def run_one_scene(model, images, crop_params=None, query_frame_num=3, image_paths=None, dtype=None, cfg=None):
    """
    Reconstruct cameras and 3D points for a single scene.

    images have been normalized to the range [0, 1] instead of [0, 255]

    Args:
        model: object exposing ``camera_predictor``, ``track_predictor`` and
            ``triangulator``.
        images: 5-D tensor unpacked as (batch, frame, channel, height, width).
        crop_params: optional per-frame crop parameters indexed as
            ``[batch, frame, :]``; entries ``-4:-2`` are used as the padding
            offsets along x and y (assumption based on how they are compared
            against the track coordinates below -- confirm against the loader).
        query_frame_num: number of query frames kept for tracking.
        image_paths: one path per frame, in the same order as ``images``.
        dtype: autocast dtype used for query-frame selection and tracking.
        cfg: config; fields read here are ``center_order``, ``query_method``,
            ``max_query_pts``, ``comple_nonvis``, ``use_poselib``,
            ``fmat_thres``, ``BA_iters``, ``max_reproj_error``,
            ``init_max_reproj_error`` and ``filter_invalid_frame``.

    Returns:
        dict with keys ``pred_cameras_PT3D``, ``extrinsics_opencv``,
        ``intrinsics_opencv``, ``points3D``, ``points3D_rgb``,
        ``reconstruction``, ``images`` and ``raw_image_paths``.
        When ``cfg.center_order`` is set, per-frame entries follow the
        re-centered frame order (the order is not switched back), and
        ``raw_image_paths`` is kept aligned with ``images``.

    Raises:
        ImportError: if ``cfg.use_poselib`` is set but the poselib-based
            estimator could not be imported.
    """
    batch_num, frame_num, image_dim, height, width = images.shape
    device = images.device
    reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width)

    predictions = {}

    camera_predictor = model.camera_predictor
    track_predictor = model.track_predictor
    triangulator = model.triangulator

    # Find the query frames
    # First use DINO to find the most common frame among all the input frames
    # i.e., the one has highest (average) cosine similarity to all others
    # Then use farthest_point_sampling to find the next ones
    # The number of query frames is determined by query_frame_num

    with autocast(dtype=dtype):
        query_frame_indexes = find_query_frame_indexes(reshaped_image, camera_predictor, frame_num)

    raw_image_paths = image_paths
    image_paths = [os.path.basename(imgpath) for imgpath in image_paths]

    if cfg.center_order:
        # The code below switches the first frame (frame 0) to the most common frame
        center_frame_index = query_frame_indexes[0]
        center_order = calculate_index_mappings(center_frame_index, frame_num, device=device)

        images = switch_tensor_order([images], center_order, dim=1)[0]
        if crop_params is not None:
            # crop_params is optional (see the padding mask below), so only
            # reorder it when it was actually provided.
            crop_params = switch_tensor_order([crop_params], center_order, dim=1)[0]
        reshaped_image = switch_tensor_order([reshaped_image], center_order, dim=0)[0]

        # Reorder BOTH path lists: everything downstream (triangulator outputs,
        # valid_frame_mask, returned images) is expressed in the new order, so
        # the raw paths must follow to stay aligned with them.
        frame_order = center_order.cpu().numpy().tolist()
        image_paths = [image_paths[i] for i in frame_order]
        raw_image_paths = [raw_image_paths[i] for i in frame_order]

        # Also update query_frame_indexes: a query that pointed at the old
        # frame 0 now has to point at the slot that frame was moved to.
        query_frame_indexes = [center_frame_index if x == 0 else x for x in query_frame_indexes]
        query_frame_indexes[0] = 0

    # only pick query_frame_num
    query_frame_indexes = query_frame_indexes[:query_frame_num]

    # Prepare image feature maps for tracker
    fmaps_for_tracker = track_predictor.process_images_to_fmaps(images)

    # Predict tracks
    with autocast(dtype=dtype):
        pred_track, pred_vis, pred_score = predict_tracks(
            cfg.query_method,
            cfg.max_query_pts,
            track_predictor,
            images,
            fmaps_for_tracker,
            query_frame_indexes,
            frame_num,
            device,
            cfg,
        )

        if cfg.comple_nonvis:
            pred_track, pred_vis, pred_score = comple_nonvis_frames(
                track_predictor,
                images,
                fmaps_for_tracker,
                frame_num,
                device,
                pred_track,
                pred_vis,
                pred_score,
                200,  # min_vis: frames with fewer visible tracks get re-queried
                cfg=cfg,
            )

    torch.cuda.empty_cache()

    # If necessary, force all the predictions at the padding areas as non-visible
    if crop_params is not None:
        lower = crop_params[:, :, -4:-2].abs().to(device)
        # The upper bound mirrors the lower one: x is measured against the
        # image width and y against the image height (identical for the square
        # inputs, but correct for non-square ones as well).
        upper = torch.stack([width - lower[..., 0], height - lower[..., 1]], dim=-1)
        boundaries = torch.cat([lower, upper], dim=-1)
        hvis = torch.logical_and(
            pred_track[..., 1] >= boundaries[:, :, 1:2], pred_track[..., 1] <= boundaries[:, :, 3:4]
        )
        wvis = torch.logical_and(
            pred_track[..., 0] >= boundaries[:, :, 0:1], pred_track[..., 0] <= boundaries[:, :, 2:3]
        )
        force_vis = torch.logical_and(hvis, wvis)
        pred_vis = pred_vis * force_vis.float()

    # TODO: plot 2D matches
    if cfg.use_poselib:
        # The poselib-based estimator only exists at module level when the
        # optional import succeeded; fail with an actionable message otherwise.
        estimate_preliminary_cameras_fn = globals().get("estimate_preliminary_cameras_poselib")
        if estimate_preliminary_cameras_fn is None:
            raise ImportError("use_poselib is enabled but poselib could not be imported. Please disable use_poselib")
    else:
        estimate_preliminary_cameras_fn = estimate_preliminary_cameras

    # Estimate preliminary_cameras by recovering fundamental/essential/homography matrix from 2D matches
    # By default, we use fundamental matrix estimation with 7p/8p+LORANSAC
    # All the operations are batched and differentiable (if necessary)
    # except when you enable use_poselib to save GPU memory
    _, preliminary_dict = estimate_preliminary_cameras_fn(
        pred_track,
        pred_vis,
        width,
        height,
        tracks_score=pred_score,
        max_error=cfg.fmat_thres,
        loopresidual=True,
        # max_ransac_iters=cfg.max_ransac_iters,
    )

    pose_predictions = camera_predictor(reshaped_image, batch_size=batch_num)

    pred_cameras = pose_predictions["pred_cameras"]

    # Conduct Triangulation and Bundle Adjustment
    (
        BA_cameras_PT3D,
        extrinsics_opencv,
        intrinsics_opencv,
        points3D,
        points3D_rgb,
        reconstruction,
        valid_frame_mask,
    ) = triangulator(
        pred_cameras,
        pred_track,
        pred_vis,
        images,
        preliminary_dict,
        image_paths=image_paths,
        crop_params=crop_params,
        pred_score=pred_score,
        fmat_thres=cfg.fmat_thres,
        BA_iters=cfg.BA_iters,
        max_reproj_error=cfg.max_reproj_error,
        init_max_reproj_error=cfg.init_max_reproj_error,
        cfg=cfg,
    )

    # if cfg.center_order:
    #     # NOTE we changed the image order previously, now we need to switch it back
    #     BA_cameras_PT3D = BA_cameras_PT3D[center_order]
    #     extrinsics_opencv = extrinsics_opencv[center_order]
    #     intrinsics_opencv = intrinsics_opencv[center_order]

    if cfg.filter_invalid_frame:
        # Drop the frames flagged invalid by the triangulator from both the
        # path list and the image tensor (the batch dimension is removed here).
        keep_flags = valid_frame_mask.cpu().numpy().tolist()
        raw_image_paths = np.array(raw_image_paths)[keep_flags].tolist()
        images = images[0][valid_frame_mask]

    predictions["pred_cameras_PT3D"] = BA_cameras_PT3D
    predictions["extrinsics_opencv"] = extrinsics_opencv
    predictions["intrinsics_opencv"] = intrinsics_opencv
    predictions["points3D"] = points3D
    predictions["points3D_rgb"] = points3D_rgb
    predictions["reconstruction"] = reconstruction
    predictions["images"] = images
    predictions["raw_image_paths"] = raw_image_paths
    return predictions


def predict_tracks(
    query_method,
    max_query_pts,
    track_predictor,
    images,
    fmaps_for_tracker,
    query_frame_indexes,
    frame_num,
    device,
    cfg=None,
):
    """
    Track query points from each query frame through the whole sequence.

    For every index in ``query_frame_indexes`` the function extracts query
    points on that frame, runs ``track_predictor`` with that frame moved to
    the front, restores the original frame order on the outputs, and finally
    concatenates the per-query results along dim=2.

    Args:
        query_method: keypoint method string forwarded to ``get_query_points``.
        max_query_pts: maximum number of query points per query frame.
        track_predictor: callable returning a 4-tuple whose first, third and
            fourth items are the fine tracks, visibilities and scores.
        images: image tensor indexed as ``images[:, frame]``.
        fmaps_for_tracker: feature maps matching ``images``.
        query_frame_indexes: iterable of frame indices used as queries.
        frame_num: total number of frames.
        device: device on which the index mapping is created.
        cfg: unused here; kept for interface compatibility.

    Returns:
        Tuple ``(pred_track, pred_vis, pred_score)``.
    """
    tracks_per_query = []
    vis_per_query = []
    scores_per_query = []

    for query_index in query_frame_indexes:
        print(f"Predicting tracks with query_index = {query_index}")

        # Keypoints are detected on the query frame itself
        points_on_query = get_query_points(images[:, query_index], query_method, max_query_pts)

        # Move the query frame to the front; the tracker is simpler when the
        # query is always frame 0
        frame_order = calculate_index_mappings(query_index, frame_num, device=device)
        reordered_inputs = switch_tensor_order([images, fmaps_for_tracker], frame_order)
        images_feed, fmaps_feed = reordered_inputs

        # Run the tracker (the second output is not used)
        tracker_outputs = track_predictor(images_feed, points_on_query, fmaps=fmaps_feed)
        track_fine, _, vis, score = tracker_outputs

        # Undo the reordering so the outputs follow the caller's frame order
        restored = switch_tensor_order([track_fine, vis, score], frame_order)
        track_fine, vis, score = restored

        tracks_per_query.append(track_fine)
        vis_per_query.append(vis)
        scores_per_query.append(score)

    # Merge the results of all query frames along dim=2
    merged_track = torch.cat(tracks_per_query, dim=2)
    merged_vis = torch.cat(vis_per_query, dim=2)
    merged_score = torch.cat(scores_per_query, dim=2)
    return merged_track, merged_vis, merged_score


def comple_nonvis_frames(
    track_predictor,
    images,
    fmaps_for_tracker,
    frame_num,
    device,
    pred_track,
    pred_vis,
    pred_score,
    min_vis=500,
    cfg=None,
):
    """
    Add extra tracks for frames that have too few visible points.

    A frame counts as under-covered when fewer than ``min_vis`` of its entries
    in ``pred_vis`` exceed 0.05. The first under-covered frame is used as a new
    query frame and its tracks are appended along dim=2. If that same frame is
    still under-covered afterwards, all remaining under-covered frames are
    queried once more with the fixed "sp+sift+aliked" method and half of
    ``cfg.max_query_pts``, and the loop stops.

    Args:
        track_predictor, images, fmaps_for_tracker, frame_num, device:
            forwarded unchanged to ``predict_tracks``.
        pred_track, pred_vis, pred_score: current predictions; the leading
            dimension of ``pred_vis`` is squeezed, so batch size 1 is assumed.
        min_vis: minimum number of visible points required per frame.
        cfg: config providing ``query_method`` and ``max_query_pts``.

    Returns:
        Tuple ``(pred_track, pred_vis, pred_score)`` with the extra tracks
        appended (unchanged inputs if every frame is already covered).
    """

    def under_covered_frames(vis):
        # Indices of the frames whose visible-point count is below min_vis
        visible_per_frame = (vis.squeeze(0) > 0.05).sum(-1)
        return torch.nonzero(visible_per_frame < min_vis).squeeze(-1).tolist()

    def track_from(method, num_points, query_frames):
        # Thin wrapper so both branches below share the fixed arguments
        return predict_tracks(
            method,
            num_points,
            track_predictor,
            images,
            fmaps_for_tracker,
            query_frames,
            frame_num,
            device,
            cfg,
        )

    previous_query = -1
    pending = under_covered_frames(pred_vis)

    while len(pending) > 0:
        print("Processing non visible frames")
        print(pending)

        # Same frame twice in a row means the last attempt did not help
        is_retry = pending[0] == previous_query
        if is_retry:
            print("The non vis frame still does not has enough 2D matches")
            extra = track_from("sp+sift+aliked", cfg.max_query_pts // 2, pending)
        else:
            previous_query = pending[0]
            extra = track_from(cfg.query_method, cfg.max_query_pts, [previous_query])

        extra_track, extra_vis, extra_score = extra

        # concat predictions
        pred_track = torch.cat([pred_track, extra_track], dim=2)
        pred_vis = torch.cat([pred_vis, extra_vis], dim=2)
        pred_score = torch.cat([pred_score, extra_score], dim=2)

        if is_retry:
            # The fallback pass is only ever run once
            break

        pending = under_covered_frames(pred_vis)

    return pred_track, pred_vis, pred_score


def find_query_frame_indexes(reshaped_image, camera_predictor, query_frame_num, image_size=336):
    """
    Pick query frames using backbone patch-token similarity.

    The frames are resized to ``image_size`` x ``image_size``, encoded with
    ``camera_predictor.backbone``, and compared pairwise. The frame with the
    largest summed similarity to the others becomes the starting point of a
    farthest-point sampling over ``100 - similarity``.

    Args:
        reshaped_image: image tensor accepted by ``F.interpolate`` with a 2-D
            target size (presumably frames x channels x H x W -- confirm).
        camera_predictor: object exposing ``_resnet_normalize_image`` and
            ``backbone``.
        query_frame_num: number of indices requested from
            ``farthest_point_sampling``.
        image_size: side length used for the downsampled images.

    Returns:
        Whatever ``farthest_point_sampling`` returns, starting from the most
        common frame.
    """
    # High resolution is unnecessary for this step, so work on small images
    target_size = (image_size, image_size)
    small_rgbs = F.interpolate(reshaped_image, target_size, mode="bilinear", align_corners=True)
    small_rgbs = camera_predictor._resnet_normalize_image(small_rgbs)

    # Patch-level features from the backbone
    backbone_out = camera_predictor.backbone(small_rgbs, is_training=True)
    patch_tokens = backbone_out["x_norm_patchtokens"]

    # NOTE(review): the L2 normalization runs along dim=1 of the patch tokens;
    # confirm that this is the intended axis.
    tokens = F.normalize(patch_tokens, p=2, dim=1).permute(1, 0, 2)

    # Batched pairwise products, averaged over the leading dimension
    similarity = torch.bmm(tokens, tokens.transpose(-1, -2)).mean(dim=0)

    # The distance must be derived BEFORE the diagonal is overwritten below
    distance = 100 - similarity

    # Ignore self-pairing when ranking the frames
    similarity.fill_diagonal_(-100)
    most_common_frame_index = similarity.sum(dim=1).argmax().item()

    # Farthest-point sampling starting from the most common frame
    # (frames are not allowed to be found twice)
    return farthest_point_sampling(distance, query_frame_num, most_common_frame_index)


def get_query_points(query_image, query_method, max_query_num=4096, det_thres=0.005):
    """
    Detect query keypoints on one frame with one or more extractors.

    Feel free to modify for your own.

    Args:
        query_image: image tensor passed to each extractor's ``extract``; the
            extractors are placed on ``query_image.device``.
        query_method: extractor names joined by "+". A name selects SuperPoint
            if it contains "sp", else SIFT if it contains "sift", else ALIKED
            if it contains "aliked".
        max_query_num: maximum number of keypoints per extractor, and also the
            cap on the total number of returned points.
        det_thres: detection threshold for SuperPoint and ALIKED.

    Returns:
        Keypoints of all extractors concatenated along dim=1. If there are
        more than ``max_query_num`` of them, a random subset of that size is
        returned.

    Raises:
        NotImplementedError: if a method name matches no known extractor.
    """
    pred_points = []

    for method in query_method.split("+"):
        if "sp" in method:
            extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres)
        elif "sift" in method:
            extractor = SIFT(max_num_keypoints=max_query_num)
        elif "aliked" in method:
            extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
        else:
            raise NotImplementedError(f"query method {method} is not supprted now")

        # Follow the image's device instead of hard-coding .cuda(), so the CPU
        # fallback and non-default GPUs work as well.
        extractor = extractor.to(query_image.device).eval()

        query_points = extractor.extract(query_image)["keypoints"]
        pred_points.append(query_points)

    query_points = torch.cat(pred_points, dim=1)

    if query_points.shape[1] > max_query_num:
        # Several extractors together may exceed the budget: subsample randomly
        random_point_indices = torch.randperm(query_points.shape[1])[:max_query_num]
        query_points = query_points[:, random_point_indices, :]

    return query_points


def seed_all_random_engines(seed: int) -> None:
    """Seed the NumPy, PyTorch and Python ``random`` generators with ``seed``."""
    # Same order as before: NumPy, then PyTorch, then the stdlib generator
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(seed)