import os
from pathlib import Path
import sys
import time

__dir__ = os.path.dirname(os.path.abspath(__file__))

sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

import numpy as np
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from tools.engine import Config
from tools.utility import ArgsParser
from tools.utils.ckpt import load_ckpt
from tools.utils.logging import get_logger
from tools.utils.utility import get_image_file_list
from tools.infer_det import replace_batchnorm

logger = get_logger()

root_dir = Path(__file__).resolve().parent
DEFAULT_CFG_PATH_REC_SERVER = str(root_dir /
                                  '../configs/rec/svtrv2/svtrv2_ch.yml')
DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml')
DEFAULT_DICT_PATH_REC = str(root_dir / './utils/ppocr_keys_v1.txt')

MODEL_NAME_REC = './openocr_repsvtr_ch.pth'  # model file name
DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth'  # model file URL
MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth'  # model file name
DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth'  # model file URL


def check_and_download_model(model_name: str, url: str):
    """
    检查预训练模型是否存在,若不存在则从指定 URL 下载到固定缓存目录。

    Args:
        model_name (str): 模型文件的名称,例如 "model.pt"
        url (str): 模型文件的下载地址

    Returns:
        str: 模型文件的完整路径
    """
    if os.path.exists(model_name):
        return model_name

    # The cache directory is fixed to ".cache/openocr" under the user's home directory
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name

    # If the model file already exists in the cache, return its path directly
    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)

    # Otherwise, download the model
    logger.info(f'Model not found. Downloading from {url}...')

    # Create the cache directory if it does not exist
    cache_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Download the file
        import urllib.request
        with urllib.request.urlopen(url) as response, open(model_path,
                                                           'wb') as out_file:
            out_file.write(response.read())
        logger.info(f'Model downloaded and saved at: {model_path}')
        return str(model_path)

    except Exception as e:
        logger.error(f'Error downloading the model: {e}')
        # Ask the user to download the model manually
        logger.error(
            f'Unable to download the model automatically. '
            f'Please download the model manually from the following URL:\n{url}\n'
            f'and save it to: {model_name} or {model_path}')
        raise RuntimeError(
            f'Failed to download the model. Please download it manually from {url} '
            f'and save it to {model_path}') from e


class RatioRecTVReisze(object):

    def __init__(self, cfg):
        self.max_ratio = cfg['Eval']['loader'].get('max_ratio', 12)
        self.base_shape = cfg['Eval']['dataset'].get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = cfg['Eval']['dataset'].get('base_h', 32)
        self.interpolation = T.InterpolationMode.BICUBIC
        transforms = []
        transforms.extend([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])
        self.transforms = T.Compose(transforms)
        self.ceil = cfg['Eval']['dataset'].get('ceil', False)

    def __call__(self, data):
        img = data['image']
        imgH = self.base_h
        w, h = img.size
        if self.ceil:
            gen_ratio = int(float(w) / float(h)) + 1
        else:
            gen_ratio = max(1, round(float(w) / float(h)))
        ratio_resize = min(gen_ratio, self.max_ratio)
        if ratio_resize <= 4:
            imgW, imgH = self.base_shape[ratio_resize - 1]
        else:
            imgW, imgH = [self.base_h * ratio_resize, self.base_h]
        resized_w = imgW
        resized_image = F.resize(img, (imgH, resized_w),
                                 interpolation=self.interpolation)
        img = self.transforms(resized_image)
        data['image'] = img
        return data
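# A worked example of the ratio bucketing above, using the default base_shape
# [[64, 64], [96, 48], [112, 40], [128, 32]] and base_h = 32 (values are illustrative):
#   - a 100x32 crop (w/h ~ 3.1, rounds to 3)  -> base_shape[2] -> resized to H=40, W=112
#   - a 320x32 crop (w/h = 10, ratio > 4)     -> resized to H=32, W=32*10=320
#   - ratios above max_ratio (12 by default) are clamped, so W never exceeds 32*12=384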


def build_rec_process(cfg):
    transforms = []
    ratio_resize_flag = True
    for op in cfg['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Resize' in op_name:
            ratio_resize_flag = False
        if 'Label' in op_name:
            continue
        elif op_name in ['RecResizeImg']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            if cfg['Architecture']['algorithm'] in ['SAR', 'RobustScanner']:
                if 'valid_ratio' in op[op_name]['keep_keys']:
                    op[op_name]['keep_keys'] = ['image', 'valid_ratio']
                else:
                    op[op_name]['keep_keys'] = ['image']
            else:
                op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    return transforms, ratio_resize_flag
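# build_rec_process consumes the usual OpenOCR-style Eval transform list. A minimal,
# purely illustrative config fragment (op names other than RecResizeImg/KeepKeys and
# all parameter values are assumptions):
#
#   Eval:
#     dataset:
#       transforms:
#         - DecodeImage: {img_mode: BGR, channel_first: False}
#         - RecResizeImg: {image_shape: [3, 48, 320]}   # infer_mode is forced to True
#         - KeepKeys: {keep_keys: [image, label]}        # trimmed to [image] here
#
# Label-related ops are skipped, and if no '*Resize*' op is present the returned
# ratio_resize_flag tells OpenRecognizer.__init__ below to insert RatioRecTVReisze.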


def set_device(device, numId=0):
    if device == 'gpu' and torch.cuda.is_available():
        return torch.device(f'cuda:{numId}')
    if device == 'gpu':
        logger.info('GPU is not available, using CPU.')
    return torch.device('cpu')
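# Example: set_device('gpu', numId=0) returns torch.device('cuda:0') when CUDA is
# available, and torch.device('cpu') otherwise.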


class OpenRecognizer(object):

    def __init__(self, config=None, mode='mobile', numId=0):
        """
        初始化方法。

        Args:
            config (dict, optional): 配置信息。默认为None。
            mode (str, optional): 模式,'server' 或 'mobile'。默认为'mobile'。
            numId (int, optional): 设备编号。默认为0。

        Returns:
            None

        Raises:
            无

        """
        if config is None:
            if mode == 'server':
                config = Config(
                    DEFAULT_CFG_PATH_REC_SERVER).cfg  # server model
                if not os.path.exists(config['Global']['pretrained_model']):
                    config['Global'][
                        'pretrained_model'] = check_and_download_model(
                            MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER)
            else:
                config = Config(DEFAULT_CFG_PATH_REC).cfg  # mobile model
                if not os.path.exists(config['Global']['pretrained_model']):
                    config['Global'][
                        'pretrained_model'] = check_and_download_model(
                            MODEL_NAME_REC, DOWNLOAD_URL_REC)
            config['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC
        else:
            if config['Architecture']['algorithm'] == 'SVTRv2_mobile':
                if not os.path.exists(config['Global']['pretrained_model']):
                    config['Global'][
                        'pretrained_model'] = check_and_download_model(
                            MODEL_NAME_REC, DOWNLOAD_URL_REC)
                config['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC
            elif config['Architecture']['algorithm'] == 'SVTRv2_server':
                if not os.path.exists(config['Global']['pretrained_model']):
                    config['Global'][
                        'pretrained_model'] = check_and_download_model(
                            MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER)
                config['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC
        global_config = config['Global']
        self.cfg = config
        if global_config['pretrained_model'] is None:
            global_config[
                'pretrained_model'] = global_config['output_dir'] + '/best.pth'
        # build post process
        from openrec.modeling import build_model as build_rec_model
        from openrec.postprocess import build_post_process
        from openrec.preprocess import create_operators, transform
        self.transform = transform
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)
        char_num = self.post_process_class.get_character_num()
        config['Architecture']['Decoder']['out_channels'] = char_num
        self.model = build_rec_model(config['Architecture'])
        load_ckpt(self.model, config)

        self.device = set_device(global_config['device'], numId=numId)
        self.model.eval()
        replace_batchnorm(self.model.encoder)
        self.model.to(device=self.device)

        transforms, ratio_resize_flag = build_rec_process(self.cfg)
        global_config['infer_mode'] = True
        self.ops = create_operators(transforms, global_config)
        if ratio_resize_flag:
            ratio_resize = RatioRecTVReisze(cfg=self.cfg)
            self.ops.insert(-1, ratio_resize)

    def __call__(self,
                 img_path=None,
                 img_numpy_list=None,
                 img_numpy=None,
                 batch_num=1):
        """
        调用函数,处理输入图像,并返回识别结果。

        Args:
            img_path (str, optional): 图像文件的路径。默认为 None。
            img_numpy_list (list, optional): 包含多个图像 numpy 数组的列表。默认为 None。
            img_numpy (numpy.ndarray, optional): 单个图像的 numpy 数组。默认为 None。
            batch_num (int, optional): 每次处理的图像数量。默认为 1。

        Returns:
            list: 包含识别结果的列表,每个元素为一个字典,包含文件路径(如果有的话)、文本、分数和延迟时间。

        Raises:
            Exception: 如果没有提供图像路径或 numpy 数组,则引发异常。
        """

        if img_numpy is not None:
            img_numpy_list = [img_numpy]
            num_img = 1
        elif img_path is not None:
            img_path = get_image_file_list(img_path)
            num_img = len(img_path)
        elif img_numpy_list is not None:
            num_img = len(img_numpy_list)
        else:
            raise Exception('No input image path or numpy array.')
        results = []
        for start_idx in range(0, num_img, batch_num):
            batch_data = []
            batch_others = []
            batch_file_names = []

            max_width, max_height = 0, 0
            # Prepare batch data
            for img_idx in range(start_idx, min(start_idx + batch_num,
                                                num_img)):
                if img_numpy_list is not None:
                    img = img_numpy_list[img_idx]
                    data = {'image': img}
                elif img_path is not None:
                    file_name = img_path[img_idx]
                    with open(file_name, 'rb') as f:
                        img = f.read()
                        data = {'image': img}
                    data = self.transform(data, self.ops[:1])
                    batch_file_names.append(file_name)
                batch = self.transform(data, self.ops[1:])
                if self.cfg['Architecture']['algorithm'] in [
                        'SAR', 'RobustScanner'
                ]:
                    valid_ratio = np.expand_dims(batch[-1], axis=0)
                    batch_others.append(valid_ratio)
                    # others = [torch.from_numpy(valid_ratio).to(device=self.device)]
                resized_image = batch[0]
                h, w = resized_image.shape[-2:]
                max_width = max(max_width, w)
                max_height = max(max_height, h)
                batch_data.append(batch[0])

            padded_batch_data = []
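            # Pad every image in the batch to the largest height and width seen above
            # so they can be stacked into one tensor. For example, crops resized to
            # (3, 32, 128) and (3, 48, 192) are both zero-padded to (3, 48, 192).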
            for resized_image in batch_data:
                padded_image = np.zeros([1, 3, max_height, max_width],
                                        dtype=np.float32)
                h, w = resized_image.shape[-2:]

                # Apply bottom-right zero padding
                padded_image[:, :, :h, :w] = resized_image
                padded_batch_data.append(padded_image)

            if batch_others:
                others = np.concatenate(batch_others, axis=0)
            else:
                others = None
            images = np.concatenate(padded_batch_data, axis=0)
            images = torch.from_numpy(images).to(device=self.device)

            with torch.no_grad():
                t_start = time.time()
                preds = self.model(images, others)
                t_cost = time.time() - t_start
            post_results = self.post_process_class(preds)

            for i, post_result in enumerate(post_results):
                if img_path is not None:
                    info = {
                        'file': batch_file_names[i],
                        'text': post_result[0],
                        'score': post_result[1],
                        'elapse': t_cost
                    }
                else:
                    info = {
                        'text': post_result[0],
                        'score': post_result[1],
                        'elapse': t_cost
                    }
                results.append(info)

        return results
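# Illustrative Python-API usage of OpenRecognizer (a sketch; the image path is a
# placeholder, and missing weights are downloaded on first use):
#
#   rec = OpenRecognizer(mode='mobile')            # or mode='server' for SVTRv2
#   res = rec(img_path='path/to/word_crop.jpg')[0]
#   print(res['text'], res['score'], res['elapse'])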


def main(cfg):
    model = OpenRecognizer(cfg)

    save_res_path = cfg['Global']['output_dir']
    if not os.path.exists(save_res_path):
        os.makedirs(save_res_path)

    t_sum = 0
    sample_num = 0
    max_len = cfg['Global']['max_text_length']
    text_len_time = [0 for _ in range(max_len)]
    text_len_num = [0 for _ in range(max_len)]

    with open(save_res_path + '/rec_results.txt', 'wb') as fout:
        for file in get_image_file_list(cfg['Global']['infer_img']):

            preds_result = model(img_path=file, batch_num=1)[0]

            rec_text = preds_result['text']
            score = preds_result['score']
            t_cost = preds_result['elapse']
            info = rec_text + '\t' + str(score)
            text_len_num[min(max_len - 1, len(rec_text))] += 1
            text_len_time[min(max_len - 1, len(rec_text))] += t_cost
            logger.info(
                f'{sample_num} {file}\t result: {info}, time cost: {t_cost}')
            otstr = file + '\t' + info + '\n'
            t_sum += t_cost
            fout.write(otstr.encode())
            sample_num += 1

    logger.info(f'Text length histogram: {text_len_num}')
    w_avg_t_cost = []
    for l_t_cost, l_num in zip(text_len_time, text_len_num):
        if l_num != 0:
            w_avg_t_cost.append(l_t_cost / l_num)
    logger.info(f'Per-length average time cost: {w_avg_t_cost}')
    w_avg_t_cost = sum(w_avg_t_cost) / len(w_avg_t_cost)

    logger.info(
        f'Sample num: {sample_num}, Avg time cost per sample: {t_sum/sample_num}, '
        f'Length-weighted avg time cost: {w_avg_t_cost}')
    logger.info('success!')


if __name__ == '__main__':
    FLAGS = ArgsParser().parse_args()
    cfg = Config(FLAGS.config)
    FLAGS = vars(FLAGS)
    opt = FLAGS.pop('opt')
    cfg.merge_dict(FLAGS)
    cfg.merge_dict(opt)
    main(cfg.cfg)
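# Example command line (illustrative; the script path and the -c/-o flag names are
# assumptions based on the usual OpenOCR ArgsParser, and the paths are placeholders):
#
#   python tools/infer_rec.py -c configs/rec/svtrv2/repsvtr_ch.yml \
#       -o Global.infer_img=./test_imgs Global.output_dir=./rec_output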