from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from pathlib import Path
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
import argparse
import numpy as np
import copy
import time
import cv2
import json
from PIL import Image
from tools.utils.utility import get_image_file_list, check_and_read
from tools.infer_rec import OpenRecognizer
from tools.infer_det import OpenDetector
from tools.engine import Config
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
from tools.utils.logging import get_logger

root_dir = Path(__file__).resolve().parent
DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')
DEFAULT_CFG_PATH_REC_SERVER = str(root_dir /
                                  '../configs/det/svtrv2/svtrv2_ch.yml')
DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml')

logger = get_logger()

MODEL_NAME_DET = './openocr_det_repvit_ch.pth'  # 模型文件名称
DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth'  # 模型文件 URL
MODEL_NAME_REC = './openocr_repsvtr_ch.pth'  # 模型文件名称
DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth'  # 模型文件 URL
MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth'  # 模型文件名称
DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth'  # 模型文件 URL


def check_and_download_model(model_name: str, url: str):
    """
    检查预训练模型是否存在,若不存在则从指定 URL 下载到固定缓存目录。

    Args:
        model_name (str): 模型文件的名称,例如 "model.pt"
        url (str): 模型文件的下载地址

    Returns:
        str: 模型文件的完整路径
    """
    if os.path.exists(model_name):
        return model_name

    # 固定缓存路径为用户主目录下的 ".cache/openocr"
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name

    # 如果模型文件已存在,直接返回路径
    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)

    # 如果文件不存在,下载模型
    logger.info(f'Model not found. Downloading from {url}...')

    # 创建缓存目录(如果不存在)
    cache_dir.mkdir(parents=True, exist_ok=True)

    try:
        # 下载文件
        import urllib.request
        with urllib.request.urlopen(url) as response, open(model_path,
                                                           'wb') as out_file:
            out_file.write(response.read())
        logger.info(f'Model downloaded and saved at: {model_path}')
        return str(model_path)

    except Exception as e:
        logger.error(f'Error downloading the model: {e}')
        # 提示用户手动下载
        logger.error(
            f'Unable to download the model automatically. '
            f'Please download the model manually from the following URL:\n{url}\n'
            f'and save it to: {model_name} or {model_path}')
        raise RuntimeError(
            f'Failed to download the model. Please download it manually from {url} '
            f'and save it to {model_path}') from e


def check_and_download_font(font_path):
    if not os.path.exists(font_path):
        cache_dir = Path.home() / '.cache' / 'openocr'
        font_path = str(cache_dir / font_path)
        if os.path.exists(font_path):
            return font_path
        logger.info(f"Downloading '{font_path}' ...")
        try:
            import urllib.request
            font_url = 'https://shuiche-shop.oss-cn-chengdu.aliyuncs.com/fonts/simfang.ttf'
            urllib.request.urlretrieve(font_url, font_path)
            logger.info(f'Downloading font success: {font_path}')
        except Exception as e:
            logger.info(f'Downloading font error: {e}')
    return font_path


def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right
    args:
        dt_boxes(array):detected text boxes with shape [4, 2]
    return:
        sorted boxes(array) with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(sorted_boxes)

    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
                    _boxes[j + 1][0][0] < _boxes[j][0][0]):
                tmp = _boxes[j]
                _boxes[j] = _boxes[j + 1]
                _boxes[j + 1] = tmp
            else:
                break
    return _boxes


class OpenOCR(object):

    def __init__(self, mode='mobile', drop_score=0.5, det_box_type='quad'):
        """
        初始化函数,用于初始化OCR引擎的相关配置和组件。

        Args:
            mode (str, optional): 运行模式,可选值为'mobile'或'server'。默认为'mobile'。
            drop_score (float, optional): 检测框的置信度阈值,低于该阈值的检测框将被丢弃。默认为0.5。
            det_box_type (str, optional): 检测框的类型,可选值为'quad' and 'poly'。默认为'quad'。

        Returns:
            无返回值。

        """
        cfg_det = Config(DEFAULT_CFG_PATH_DET).cfg  # mobile model
        model_dir = check_and_download_model(MODEL_NAME_DET, DOWNLOAD_URL_DET)
        cfg_det['Global']['pretrained_model'] = model_dir
        if mode == 'server':
            cfg_rec = Config(DEFAULT_CFG_PATH_REC_SERVER).cfg  # server model
            model_dir = check_and_download_model(MODEL_NAME_REC_SERVER,
                                                 DOWNLOAD_URL_REC_SERVER)
        else:
            cfg_rec = Config(DEFAULT_CFG_PATH_REC).cfg  # mobile model
            model_dir = check_and_download_model(MODEL_NAME_REC,
                                                 DOWNLOAD_URL_REC)
        cfg_rec['Global']['pretrained_model'] = model_dir
        self.text_detector = OpenDetector(cfg_det)
        self.text_recognizer = OpenRecognizer(cfg_rec)
        self.det_box_type = det_box_type
        self.drop_score = drop_score

        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        os.makedirs(output_dir, exist_ok=True)
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):
            cv2.imwrite(
                os.path.join(output_dir,
                             f'mg_crop_{bno+self.crop_image_res_index}.jpg'),
                img_crop_list[bno],
            )
        self.crop_image_res_index += bbox_num

    def infer_single_image(self,
                           img_numpy,
                           ori_img,
                           crop_infer=False,
                           rec_batch_num=6,
                           return_mask=False):
        start = time.time()
        if crop_infer:
            dt_boxes = self.text_detector.crop_infer(
                img_numpy=img_numpy)[0]['boxes']
        else:
            det_res = self.text_detector(img_numpy=img_numpy,
                                         return_mask=return_mask)[0]
            dt_boxes = det_res['boxes']
        # logger.info(dt_boxes)
        det_time_cost = time.time() - start

        if dt_boxes is None:
            return None, None, None

        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        for bno in range(len(dt_boxes)):
            tmp_box = np.array(copy.deepcopy(dt_boxes[bno])).astype(np.float32)
            if self.det_box_type == 'quad':
                img_crop = get_rotate_crop_image(ori_img, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_img, tmp_box)
            img_crop_list.append(img_crop)

        start = time.time()
        rec_res = self.text_recognizer(img_numpy_list=img_crop_list,
                                       batch_num=rec_batch_num)
        rec_time_cost = time.time() - start

        filter_boxes, filter_rec_res = [], []
        rec_time_cost_sig = 0.0
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result['text'], rec_result['score']
            rec_time_cost_sig += rec_result['elapse']
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append([text, score])

        avg_rec_time_cost = rec_time_cost_sig / len(dt_boxes) if len(
            dt_boxes) > 0 else 0.0
        if return_mask:
            return filter_boxes, filter_rec_res, {
                'time_cost': det_time_cost + rec_time_cost,
                'detection_time': det_time_cost,
                'recognition_time': rec_time_cost,
                'avg_rec_time_cost': avg_rec_time_cost
            }, det_res['mask']

        return filter_boxes, filter_rec_res, {
            'time_cost': det_time_cost + rec_time_cost,
            'detection_time': det_time_cost,
            'recognition_time': rec_time_cost,
            'avg_rec_time_cost': avg_rec_time_cost
        }

    def __call__(self,
                 img_path=None,
                 save_dir='e2e_results/',
                 is_visualize=False,
                 img_numpy=None,
                 rec_batch_num=6,
                 crop_infer=False,
                 return_mask=False):
        """
        img_path: str, optional, default=None
            Path to the directory containing images or the image filename.
        save_dir: str, optional, default='e2e_results/'
            Directory to save prediction and visualization results. Defaults to a subfolder in img_path.
        is_visualize: bool, optional, default=False
            Visualize the results.
        img_numpy: numpy or list[numpy], optional, default=None
            numpy of an image or List of numpy arrays representing images.
        rec_batch_num: int, optional, default=6
            Batch size for text recognition.
        crop_infer: bool, optional, default=False
            Whether to use crop inference.
        """

        if img_numpy is None and img_path is None:
            raise ValueError('img_path and img_numpy cannot be both None.')
        if img_numpy is not None:
            if not isinstance(img_numpy, list):
                img_numpy = [img_numpy]
            results = []
            time_dicts = []
            for index, img in enumerate(img_numpy):
                ori_img = img.copy()
                if return_mask:
                    dt_boxes, rec_res, time_dict, mask = self.infer_single_image(
                        img_numpy=img,
                        ori_img=ori_img,
                        crop_infer=crop_infer,
                        rec_batch_num=rec_batch_num,
                        return_mask=return_mask)
                else:
                    dt_boxes, rec_res, time_dict = self.infer_single_image(
                        img_numpy=img,
                        ori_img=ori_img,
                        crop_infer=crop_infer,
                        rec_batch_num=rec_batch_num)
                if dt_boxes is None:
                    results.append([])
                    time_dicts.append({})
                    continue
                res = [{
                    'transcription': rec_res[i][0],
                    'points': np.array(dt_boxes[i]).tolist(),
                    'score': rec_res[i][1],
                } for i in range(len(dt_boxes))]
                results.append(res)
                time_dicts.append(time_dict)
            if return_mask:
                return results, time_dicts, mask
            return results, time_dicts

        image_file_list = get_image_file_list(img_path)
        save_results = []
        time_dicts_return = []
        for idx, image_file in enumerate(image_file_list):
            img, flag_gif, flag_pdf = check_and_read(image_file)
            if not flag_gif and not flag_pdf:
                img = cv2.imread(image_file)
            if not flag_pdf:
                if img is None:
                    return None
                imgs = [img]
            else:
                imgs = img
            logger.info(
                f'Processing {idx+1}/{len(image_file_list)}: {image_file}')

            res_list = []
            time_dicts = []
            for index, img_numpy in enumerate(imgs):
                ori_img = img_numpy.copy()
                dt_boxes, rec_res, time_dict = self.infer_single_image(
                    img_numpy=img_numpy,
                    ori_img=ori_img,
                    crop_infer=crop_infer,
                    rec_batch_num=rec_batch_num)
                if dt_boxes is None:
                    res_list.append([])
                    time_dicts.append({})
                    continue
                res = [{
                    'transcription': rec_res[i][0],
                    'points': np.array(dt_boxes[i]).tolist(),
                    'score': rec_res[i][1],
                } for i in range(len(dt_boxes))]
                res_list.append(res)
                time_dicts.append(time_dict)

            for index, (res, time_dict) in enumerate(zip(res_list,
                                                         time_dicts)):

                if len(res) > 0:
                    logger.info(f'Results: {res}.')
                    logger.info(f'Time cost: {time_dict}.')
                else:
                    logger.info('No text detected.')

                if len(res_list) > 1:
                    save_pred = (os.path.basename(image_file) + '_' +
                                 str(index) + '\t' +
                                 json.dumps(res, ensure_ascii=False) + '\n')
                else:
                    if len(res) > 0:
                        save_pred = (os.path.basename(image_file) + '\t' +
                                     json.dumps(res, ensure_ascii=False) +
                                     '\n')
                    else:
                        continue
                save_results.append(save_pred)
                time_dicts_return.append(time_dict)

                if is_visualize and len(res) > 0:
                    if idx == 0:
                        font_path = './simfang.ttf'
                        font_path = check_and_download_font(font_path)
                        os.makedirs(save_dir, exist_ok=True)
                        draw_img_save_dir = os.path.join(
                            save_dir, 'vis_results/')
                        os.makedirs(draw_img_save_dir, exist_ok=True)
                        logger.info(
                            f'Visualized results will be saved to {draw_img_save_dir}.'
                        )
                    dt_boxes = [res[i]['points'] for i in range(len(res))]
                    rec_res = [
                        res[i]['transcription'] for i in range(len(res))
                    ]
                    rec_score = [res[i]['score'] for i in range(len(res))]
                    image = Image.fromarray(
                        cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                    boxes = dt_boxes
                    txts = [rec_res[i] for i in range(len(rec_res))]
                    scores = [rec_score[i] for i in range(len(rec_res))]

                    draw_img = draw_ocr_box_txt(
                        image,
                        boxes,
                        txts,
                        scores,
                        drop_score=self.drop_score,
                        font_path=font_path,
                    )
                    if flag_gif:
                        save_file = image_file[:-3] + 'png'
                    elif flag_pdf:
                        save_file = image_file.replace(
                            '.pdf', '_' + str(index) + '.png')
                    else:
                        save_file = image_file
                    cv2.imwrite(
                        os.path.join(draw_img_save_dir,
                                     os.path.basename(save_file)),
                        draw_img[:, :, ::-1],
                    )

        if save_results:
            os.makedirs(save_dir, exist_ok=True)
            with open(os.path.join(save_dir, 'system_results.txt'),
                      'w',
                      encoding='utf-8') as f:
                f.writelines(save_results)
            logger.info(
                f"Results saved to {os.path.join(save_dir, 'system_results.txt')}."
            )
            if is_visualize:
                logger.info(
                    f'Visualized results saved to {draw_img_save_dir}.')
            return save_results, time_dicts_return
        else:
            logger.info('No text detected.')
            return None, None


def main():
    parser = argparse.ArgumentParser(description='OpenOCR system')
    parser.add_argument(
        '--img_path',
        type=str,
        help='Path to the directory containing images or the image filename.')
    parser.add_argument(
        '--mode',
        type=str,
        default='mobile',
        help="Mode of the OCR system, e.g., 'mobile' or 'server'.")
    parser.add_argument(
        '--save_dir',
        type=str,
        default='e2e_results/',
        help='Directory to save prediction and visualization results. \
            Defaults to ./e2e_results/.')
    parser.add_argument('--is_vis',
                        action='store_true',
                        default=False,
                        help='Visualize the results.')
    parser.add_argument('--drop_score',
                        type=float,
                        default=0.5,
                        help='Score threshold for text recognition.')
    args = parser.parse_args()

    img_path = args.img_path
    mode = args.mode
    save_dir = args.save_dir
    is_visualize = args.is_vis
    drop_score = args.drop_score

    text_sys = OpenOCR(mode=mode, drop_score=drop_score,
                       det_box_type='quad')  # det_box_type: 'quad' or 'poly'
    text_sys(img_path=img_path, save_dir=save_dir, is_visualize=is_visualize)


if __name__ == '__main__':
    main()