from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from pathlib import Path
import time

import numpy as np

import os
import sys

# Make this script's own directory and its parent directory importable so the
# `tools.*` imports below resolve when the file is run directly as a script.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

# NOTE(review): 'FLAGS_allocator_strategy' looks like a PaddlePaddle runtime
# flag; nothing in this file reads it - confirm whether it is still needed.
os.environ['FLAGS_allocator_strategy'] = 'auto_growth'

import cv2
import json
import torch
from tools.engine import Config
from tools.utility import ArgsParser
from tools.utils.ckpt import load_ckpt
from tools.utils.logging import get_logger
from tools.utils.utility import get_image_file_list

# Module-wide logger shared by every function in this file.
logger = get_logger()

# Default detector config, resolved relative to this file's directory.
root_dir = Path(__file__).resolve().parent
DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')

MODEL_NAME_DET = './openocr_det_repvit_ch.pth'  # model file name
DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth'  # model file URL


def check_and_download_model(model_name: str, url: str):
    """Return a local path to the pretrained model, downloading it if needed.

    Lookup order:

    1. ``model_name`` itself, if it already exists on disk (returned as is).
    2. ``~/.cache/openocr/<model_name>``.
    3. Otherwise ``url`` is downloaded into that cache location.

    The download is streamed into a temporary ``*.part`` file next to the
    final destination and only renamed into place after it has completed, so
    an interrupted download is never mistaken for a valid cached model on
    the next run.

    Args:
        model_name (str): Model file name or path, e.g. "model.pt".
        url (str): Download URL of the model file.

    Returns:
        str: Path to the model file.

    Raises:
        RuntimeError: If the download fails. The underlying exception is
            chained as the cause.
    """
    if os.path.exists(model_name):
        return model_name

    # The cache location is fixed to ".cache/openocr" in the user's home.
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name

    # Already cached: return it directly.
    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)

    logger.info(f'Model not found. Downloading from {url}...')

    # Create the destination directory if it does not exist yet.
    model_path.parent.mkdir(parents=True, exist_ok=True)

    # Per-process temporary name so concurrent downloads do not clobber each
    # other and a partial file never sits at the final path.
    tmp_path = model_path.with_name(f'{model_path.name}.{os.getpid()}.part')

    try:
        import shutil
        import urllib.request
        with urllib.request.urlopen(url) as response, open(tmp_path,
                                                           'wb') as out_file:
            # Stream in chunks instead of holding the whole model in memory.
            shutil.copyfileobj(response, out_file)
        # Atomic within one filesystem: readers see either nothing or the
        # complete file.
        os.replace(tmp_path, model_path)
        logger.info(f'Model downloaded and saved at: {model_path}')
        return str(model_path)

    except Exception as e:
        # Best-effort cleanup of the partial download.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        logger.error(f'Error downloading the model: {e}')
        # Tell the user how to fetch the model manually.
        logger.error(
            f'Unable to download the model automatically. '
            f'Please download the model manually from the following URL:\n{url}\n'
            f'and save it to: {model_name} or {model_path}')
        raise RuntimeError(
            f'Failed to download the model. Please download it manually from {url} '
            f'and save it to {model_path}') from e


def replace_batchnorm(net):
    """Recursively fuse sub-modules and strip BatchNorm2d layers in place.

    For every direct child of ``net``:

    - a child exposing a ``fuse`` attribute is replaced by ``child.fuse()``
      and the fused result is then processed recursively;
    - a ``torch.nn.BatchNorm2d`` child is replaced by ``torch.nn.Identity``;
    - any other child is processed recursively.

    Args:
        net (torch.nn.Module): Module whose children are rewritten in place.
    """
    for name, module in net.named_children():
        if hasattr(module, 'fuse'):
            replacement = module.fuse()
            setattr(net, name, replacement)
            # The fused module may itself contain fusable children.
            replace_batchnorm(replacement)
            continue

        if isinstance(module, torch.nn.BatchNorm2d):
            setattr(net, name, torch.nn.Identity())
            continue

        replace_batchnorm(module)


def padding_image(img, size=(640, 640)):
    """Pad an image on its bottom and right edges up to ``size``.

    The input stays anchored at the top-left corner; the added border is a
    constant ``[0, 0, 0]`` fill. No clamping is done here, so the image is
    expected to be no larger than ``size`` in either dimension.

    :param img: Image array; only ``img.shape[:2]`` (height, width) is read.
    :param size: Target size as ``(width, height)`` (default 640x640).
    :return: The padded image returned by ``cv2.copyMakeBorder``.
    """
    height, width = img.shape[:2]
    target_w, target_h = size

    # Only the bottom and right sides grow; top and left padding are zero.
    extra_rows = target_h - height
    extra_cols = target_w - width

    return cv2.copyMakeBorder(img,
                              0,
                              extra_rows,
                              0,
                              extra_cols,
                              cv2.BORDER_CONSTANT,
                              value=[0, 0, 0])


def resize_image(img, size=(640, 640), over_lap=64):
    """Pad or tile an image into crops of ``size``.

    - If the image fits inside ``size`` it is returned as a single crop,
      padded with ``padding_image`` unless it already matches exactly.
    - Otherwise it is cut into overlapping tiles. A tile that would run past
      the right/bottom edge is shifted back so it ends on the edge, and a
      tile that is still smaller than ``size`` (the image is smaller than
      the target along that axis) is padded.

    :param img: Image array; ``img.shape[:2]`` is (height, width).
    :param size: Tile size as ``(width, height)`` (default 640x640).
    :param over_lap: Overlap in pixels between neighbouring tiles.
    :return: ``(crops, positions)`` where ``positions[i]`` is
        ``[left, top, right, bottom]`` of ``crops[i]`` in ``img`` coordinates
        (before any padding).
    """
    img_height, img_width = img.shape[:2]
    target_width, target_height = size

    # Image fits in a single tile.
    if img_width <= target_width and img_height <= target_height:
        if img_width == target_width and img_height == target_height:
            return [img], [[0, 0, img_width, img_height]]
        return [padding_image(img, size)], [[0, 0, img_width, img_height]]

    crop_positions = []
    cropped_img_list = []

    # max(..., 1) guarantees at least one row/column of tiles even when a
    # dimension is <= over_lap; without it such images produced no tiles.
    row_stop = max(img_height - over_lap, 1)
    col_stop = max(img_width - over_lap, 1)

    for row_start in range(0, row_stop, target_height - over_lap):
        bottom = min(row_start + target_height, img_height)
        # A tile touching the bottom edge is shifted up to stay full-height.
        top = max(0, bottom - target_height) if bottom >= img_height \
            else row_start

        for col_start in range(0, col_stop, target_width - over_lap):
            right = min(col_start + target_width, img_width)
            # A tile touching the right edge is shifted left likewise.
            left = max(0, right - target_width) if right >= img_width \
                else col_start

            cropped_img = img[top:bottom, left:right]
            if bottom - top < target_height or right - left < target_width:
                cropped_img = padding_image(cropped_img, size)

            cropped_img_list.append(cropped_img)
            crop_positions.append([left, top, right, bottom])

    return cropped_img_list, crop_positions


def restore_preds(preds, crop_positions, original_size, thresh=0.3):
    """Stitch per-tile predictions back into one full-size map.

    Each tile prediction is binarised with ``thresh``, cut down to the
    un-padded extent given by its crop position, and added into the output
    at that position. Pixels covered by several overlapping tiles therefore
    accumulate one count per tile that fired there.

    Args:
        preds (torch.Tensor): Per-tile predictions; iterating over dim 0
            must yield one 3-D map per tile, indexed as
            ``[:, :height, :width]``.
        crop_positions (list): ``[left, top, right, bottom]`` per tile, in
            coordinates of the full image.
        original_size (tuple): ``(height, width)`` of the full image.
        thresh (float, optional): Binarisation threshold applied to every
            tile. Defaults to 0.3.

    Returns:
        torch.Tensor: Tensor of shape ``(1, 1, height, width)`` with the
        same dtype and device as ``preds``.
    """
    restored_pred = torch.zeros((1, 1, original_size[0], original_size[1]),
                                dtype=preds.dtype,
                                device=preds.device)

    for cropped_pred, (left, top, right, bottom) in zip(preds, crop_positions):
        crop_height = bottom - top
        crop_width = right - left

        # Drop the padded area of the tile before thresholding.
        mask = cropped_pred[:, :crop_height, :crop_width] > thresh
        restored_pred[:, :, top:bottom, left:right] += mask.to(preds.dtype)

    return restored_pred


def draw_det_res(dt_boxes, img, img_name, save_path):
    """Draw detected boxes on ``img`` and write the result into ``save_path``.

    The polygons are drawn directly on ``img`` (the array is modified in
    place). The output file keeps the base name of ``img_name``.

    Args:
        dt_boxes: Iterable of boxes; each is reshaped to ``(-1, 1, 2)``
            int32 points and drawn as a closed polyline.
        img: Image array to draw on.
        img_name (str): Source image path; only its base name is used.
        save_path (str): Output directory, created if missing.
    """
    for box in dt_boxes:
        polygon = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(img, [polygon], True, color=(255, 255, 0), thickness=2)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_path, exist_ok=True)
    cv2.imwrite(os.path.join(save_path, os.path.basename(img_name)), img)


def set_device(device, numId=0):
    """Resolve the configured device name to a ``torch.device``.

    Args:
        device (str): ``'gpu'`` requests CUDA; any other value selects CPU.
        numId (int, optional): CUDA device index used when a GPU is
            selected. Defaults to 0.

    Returns:
        torch.device: ``cuda:<numId>`` when a GPU was requested and CUDA is
        available, otherwise ``cpu``.
    """
    if device == 'gpu':
        if torch.cuda.is_available():
            return torch.device(f'cuda:{numId}')
        # Only report a fallback when a GPU was actually asked for; an
        # explicit CPU request should not log a misleading message.
        logger.info('GPU is not available, using CPU.')
    return torch.device('cpu')


class OpenDetector(object):
    """Text detector wrapping model construction, preprocessing and
    postprocessing behind a callable interface."""

    def __init__(self, config=None, numId=0):
        """
        Build the detection model and its pre/post-processing pipeline.

        Args:
            config (dict, optional): Configuration. Defaults to None, in
                which case the default config file is loaded. Note that the
                config is modified in place: 'pretrained_model' may be
                replaced by a downloaded path and the 'KeepKeys' transform
                is overridden.
            numId (int, optional): Device index. Defaults to 0.

        Raises:
            RuntimeError: If the pretrained model is missing and cannot be
                downloaded.
        """

        if config is None:
            config = Config(DEFAULT_CFG_PATH_DET).cfg

        # Fall back to the default checkpoint when none is configured or the
        # configured file is missing (a None value would otherwise make
        # os.path.exists raise TypeError).
        pretrained = config['Global']['pretrained_model']
        if pretrained is None or not os.path.exists(pretrained):
            config['Global']['pretrained_model'] = check_and_download_model(
                MODEL_NAME_DET, DOWNLOAD_URL_DET)

        from opendet.modeling import build_model as build_det_model
        from opendet.postprocess import build_post_process
        from opendet.preprocess import create_operators, transform
        self.transform = transform
        global_config = config['Global']

        # build model
        self.model = build_det_model(config['Architecture'])
        self.model.eval()
        load_ckpt(self.model, config)
        # Fuse/strip BatchNorm in the backbone after the weights are loaded.
        replace_batchnorm(self.model.backbone)
        self.device = set_device(config['Global']['device'], numId=numId)
        self.model.to(device=self.device)

        # create data ops: label-related transforms are skipped for
        # inference and KeepKeys is restricted to the image and its shape.
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                continue
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = ['image', 'shape']
            transforms.append(op)

        self.ops = create_operators(transforms, global_config)

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

    def _load_data(self, img_path, img_numpy_list, img_idx):
        """Return the data dict for input ``img_idx``.

        In-memory arrays are used as they are; file inputs are read as raw
        bytes and passed through the first preprocessing op only.
        """
        if img_numpy_list is not None:
            return {'image': img_numpy_list[img_idx]}
        with open(img_path[img_idx], 'rb') as f:
            data = {'image': f.read()}
        return self.transform(data, self.ops[:1])

    def crop_infer(
        self,
        img_path=None,
        img_numpy_list=None,
        img_numpy=None,
    ):
        """
        Run detection tile by tile and merge the tile predictions.

        Each image is resized so its longer side is ``2 * 640 - 64`` pixels,
        split into overlapping 640x640 tiles, run through the model as one
        batch, and the tile maps are stitched back before postprocessing.

        Input precedence is ``img_numpy``, then ``img_numpy_list``, then
        ``img_path``.

        Args:
            img_path (str or list, optional): A path accepted by
                ``get_image_file_list`` or an already expanded sequence of
                image file paths. Defaults to None.
            img_numpy_list (list, optional): List of images as numpy arrays.
                Defaults to None.
            img_numpy (numpy.ndarray, optional): A single image. Defaults to
                None.

        Returns:
            list: One dict per image with the keys 'boxes' (detected box
                points) and 'elapse' (model forward time in seconds).

        Raises:
            ValueError: If no input was provided.
        """
        if img_numpy is not None:
            img_numpy_list = [img_numpy]

        # Count from the same source the loop reads from.
        if img_numpy_list is not None:
            num_img = len(img_numpy_list)
        elif img_path is not None:
            # A plain string is expanded like in __call__; indexing it
            # directly would iterate over its characters.
            if isinstance(img_path, str):
                img_path = get_image_file_list(img_path)
            num_img = len(img_path)
        else:
            raise ValueError('No input image path or numpy array.')

        target_size = 640
        over_lap = 64
        # Longer side after resizing: two tiles sharing one overlap.
        long_side = target_size * 2 - over_lap

        results = []
        for img_idx in range(num_img):
            data = self._load_data(img_path, img_numpy_list, img_idx)
            src_img_ori = data['image']
            img_height, img_width = src_img_ori.shape[:2]

            # Keep the aspect ratio while fixing the longer side.
            if img_height > img_width:
                r_h = long_side
                r_w = img_width * long_side // img_height
            else:
                r_w = long_side
                r_h = img_height * long_side // img_width
            src_img = cv2.resize(src_img_ori, (r_w, r_h))
            # Original size plus resize ratios, passed to the postprocessor.
            shape_list_ori = np.array([[
                img_height, img_width,
                float(r_h) / img_height,
                float(r_w) / img_width
            ]])
            img_height, img_width = src_img.shape[:2]
            cropped_img_list, crop_positions = resize_image(src_img,
                                                            size=(target_size,
                                                                  target_size),
                                                            over_lap=over_lap)

            # NOTE(review): ops[-3:-1] presumably selects the normalise and
            # layout ops that precede the final KeepKeys op - confirm against
            # the Eval transforms in the config.
            image_list = [
                self.transform({'image': crop}, self.ops[-3:-1])['image']
                for crop in cropped_img_list
            ]
            images = torch.from_numpy(
                np.array(image_list)).to(device=self.device)
            with torch.no_grad():
                t_start = time.time()
                preds = self.model(images)
                t_cost = time.time() - t_start

            preds['maps'] = restore_preds(preds['maps'], crop_positions,
                                          (img_height, img_width))
            post_result = self.post_process_class(preds, shape_list_ori)
            info = {'boxes': post_result[0]['points'], 'elapse': t_cost}
            results.append(info)
        return results

    def __call__(self,
                 img_path=None,
                 img_numpy_list=None,
                 img_numpy=None,
                 return_mask=False):
        """
        Run detection on the given image(s) and return the results.

        Input precedence is ``img_numpy``, then ``img_numpy_list``, then
        ``img_path``.

        Args:
            img_path (str, optional): Image path, expanded with
                ``get_image_file_list``. Defaults to None.
            img_numpy_list (list, optional): List of images as numpy arrays.
                Defaults to None.
            img_numpy (numpy.ndarray, optional): A single image as a numpy
                array. Defaults to None.
            return_mask (bool, optional): If True, each result additionally
                carries the model's 'maps' output under the key 'mask'
                (converted to numpy when it is a tensor). Defaults to False.

        Returns:
            list: One dict per image with the keys 'boxes' (detected box
                points) and 'elapse' (model forward time in seconds), plus
                'mask' when ``return_mask`` is set.

        Raises:
            ValueError: If neither an image path nor a numpy array was
                provided.

        """

        if img_numpy is not None:
            img_numpy_list = [img_numpy]

        # Count from the same source the loop reads from.
        if img_numpy_list is not None:
            num_img = len(img_numpy_list)
        elif img_path is not None:
            img_path = get_image_file_list(img_path)
            num_img = len(img_path)
        else:
            raise ValueError('No input image path or numpy array.')

        results = []
        for img_idx in range(num_img):
            data = self._load_data(img_path, img_numpy_list, img_idx)
            batch = self.transform(data, self.ops[1:])

            # Add the batch dimension to the image and its shape record.
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = torch.from_numpy(images).to(device=self.device)
            with torch.no_grad():
                t_start = time.time()
                preds = self.model(images)
                t_cost = time.time() - t_start
            post_result = self.post_process_class(preds, shape_list)

            info = {'boxes': post_result[0]['points'], 'elapse': t_cost}
            if return_mask:
                if isinstance(preds['maps'], torch.Tensor):
                    mask = preds['maps'].detach().cpu().numpy()
                else:
                    mask = preds['maps']
                info['mask'] = mask
            results.append(info)
        return results


@torch.no_grad()
def main(cfg):
    """Run detection on every configured image and save the results.

    For each image returned by ``get_image_file_list`` for
    ``cfg['Global']['infer_img']`` the detected boxes are appended to
    ``./det_results/det_results.txt`` as ``<path>\\t<json>`` lines. When
    ``cfg['Global']['is_visualize']`` is set, an annotated copy of the image
    is also written into ``./det_results/``.

    Args:
        cfg (dict): Configuration used to build the detector.
    """
    is_visualize = cfg['Global'].get('is_visualize', False)
    model = OpenDetector(cfg)

    save_res_path = './det_results/'
    os.makedirs(save_res_path, exist_ok=True)
    res_file = os.path.join(save_res_path, 'det_results.txt')

    with open(res_file, 'wb') as fout:
        for sample_num, file in enumerate(
                get_image_file_list(cfg['Global']['infer_img'])):
            preds_result = model(img_path=file)[0]
            logger.info('{} infer_img: {}, time cost: {}'.format(
                sample_num, file, preds_result['elapse']))
            boxes = preds_result['boxes']
            dt_boxes_json = [{
                'points': np.array(box).tolist()
            } for box in boxes]
            if is_visualize:
                src_img = cv2.imread(file)
                draw_det_res(boxes, src_img, file, save_res_path)
                logger.info('The detected Image saved in {}'.format(
                    os.path.join(save_res_path, os.path.basename(file))))
            # Serialise once and reuse for both the log and the result file.
            boxes_str = json.dumps(dt_boxes_json)
            logger.info('results: {}'.format(boxes_str))
            fout.write((file + '\t' + boxes_str + '\n').encode())
        logger.info(f'Results saved to {res_file}.')

    logger.info('success!')


if __name__ == '__main__':
    # Parse the command line and load the config file it points to.
    FLAGS = ArgsParser().parse_args()
    cfg = Config(FLAGS.config)
    # Convert the parsed namespace to a dict. 'opt' is popped so that it is
    # merged separately, after the remaining flags.
    # NOTE(review): 'opt' presumably holds key=value config overrides -
    # confirm against ArgsParser.
    FLAGS = vars(FLAGS)
    opt = FLAGS.pop('opt')
    cfg.merge_dict(FLAGS)
    cfg.merge_dict(opt)
    main(cfg.cfg)