import os
import cv2
import json
import random
import glob
import torch
import einops
import numpy as np
import datetime
import torchvision

import safetensors.torch as sf
from PIL import Image


def min_resize(x, m):
    """Resize image ``x`` so that its shorter side becomes ``m`` pixels.

    The longer side is scaled by the same factor and truncated to an integer.
    ``cv2.INTER_AREA`` is used when the result is smaller than the input and
    ``cv2.INTER_LANCZOS4`` otherwise.
    """
    height, width = x.shape[0], x.shape[1]

    if height < width:
        out_h = m
        out_w = int(float(m) / float(height) * float(width))
    else:
        out_h = int(float(m) / float(width) * float(height))
        out_w = m

    shrinking = max(out_w, out_h) < max(height, width)
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(x, (out_w, out_h), interpolation=interpolation)


def d_resize(x, y):
    """Resize image ``x`` to the height and width of the reference array ``y``.

    Only the first two dimensions of ``y`` are read, so ``y`` may be either
    ``(H, W)`` or ``(H, W, C)``.  ``cv2.INTER_AREA`` is used when the target's
    shorter side is smaller than the source's shorter side, and
    ``cv2.INTER_LANCZOS4`` otherwise.
    """
    H, W = y.shape[:2]
    new_min = min(H, W)
    raw_min = min(x.shape[0], x.shape[1])
    interpolation = cv2.INTER_AREA if new_min < raw_min else cv2.INTER_LANCZOS4
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(x, (W, H), interpolation=interpolation)


def resize_and_center_crop(image, target_width, target_height):
    """Scale ``image`` to cover the target size, then crop the centre.

    ``image`` is a NumPy array accepted by ``PIL.Image.fromarray``.  If it
    already has the requested height and width it is returned unchanged
    (the same object).  Otherwise it is resized with Lanczos filtering so that
    both sides are at least as large as the target, and a centred window of
    exactly ``target_width`` x ``target_height`` is returned as a new array.
    """
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)

    # Use an integer crop box whose right/bottom edges are derived from the
    # left/top edges.  Passing independent fractional edges lets each one be
    # rounded separately (half-to-even), which can yield a crop one pixel
    # larger than requested when the target size is odd.
    left = int(round((resized_width - target_width) / 2))
    top = int(round((resized_height - target_height) / 2))
    crop_box = (left, top, left + target_width, top + target_height)
    cropped_image = resized_image.crop(crop_box)
    return np.array(cropped_image)


def resize_and_center_crop_pytorch(image, target_width, target_height):
    """Scale a 4-D tensor to cover the target size, then crop the centre.

    The last two dimensions of ``image`` are treated as height and width.
    If they already match the target the input tensor is returned as is.
    Otherwise the tensor is bilinearly interpolated so both sides are at
    least as large as the target and a centred window is sliced out.
    """
    _, _, src_h, src_w = image.shape

    if src_h == target_height and src_w == target_width:
        return image

    scale_factor = max(target_width / src_w, target_height / src_h)
    scaled_w = int(round(src_w * scale_factor))
    scaled_h = int(round(src_h * scale_factor))

    scaled = torch.nn.functional.interpolate(
        image,
        size=(scaled_h, scaled_w),
        mode='bilinear',
        align_corners=False,
    )

    y0 = (scaled_h - target_height) // 2
    x0 = (scaled_w - target_width) // 2
    return scaled[:, :, y0:y0 + target_height, x0:x0 + target_width]


def resize_without_crop(image, target_width, target_height):
    """Resize ``image`` to exactly the target size, ignoring aspect ratio.

    Returns the input object untouched when it already has the requested
    height and width; otherwise returns a new array produced by a Lanczos
    resize through PIL.
    """
    already_sized = image.shape[0] == target_height and image.shape[1] == target_width
    if already_sized:
        return image

    resized = Image.fromarray(image).resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized)


def just_crop(image, w, h):
    """Centre-crop ``image`` to the largest window with aspect ratio ``w:h``.

    No resampling is done: the returned array is a slice of ``image`` whose
    size is ``(h * k, w * k)`` (rounded) for the largest ``k`` that fits.
    When the image already measures ``h`` x ``w`` it is returned unchanged.
    """
    if image.shape[0] == h and image.shape[1] == w:
        return image

    src_h, src_w = image.shape[:2]
    ratio = min(src_h / h, src_w / w)
    crop_w = int(round(w * ratio))
    crop_h = int(round(h * ratio))

    x0 = (src_w - crop_w) // 2
    y0 = (src_h - crop_h) // 2
    return image[y0:y0 + crop_h, x0:x0 + crop_w]


def write_to_json(data, file_path):
    """Serialize ``data`` as indented JSON to ``file_path``.

    The JSON is first written to ``file_path + ".tmp"`` and then moved over
    the destination with ``os.replace``, so the destination is only replaced
    once the dump has finished.
    """
    staging_path = file_path + ".tmp"
    with open(staging_path, 'wt', encoding='utf-8') as handle:
        json.dump(data, handle, indent=4)
    os.replace(staging_path, file_path)
    return


def read_from_json(file_path):
    """Load and return the JSON document stored at ``file_path`` (UTF-8)."""
    with open(file_path, 'rt', encoding='utf-8') as handle:
        return json.load(handle)


def get_active_parameters(m):
    """Return a ``{name: parameter}`` dict of the parameters of ``m`` that require grad."""
    active = {}
    for name, parameter in m.named_parameters():
        if parameter.requires_grad:
            active[name] = parameter
    return active


def cast_training_params(m, dtype=torch.float32):
    """Cast the trainable parameters of ``m`` to ``dtype`` in place.

    Parameters that do not require grad are left untouched.  Returns a
    ``{name: parameter}`` dict containing only the parameters that were cast.
    """
    trainable = {}
    for name, parameter in m.named_parameters():
        if not parameter.requires_grad:
            continue
        # Replace the underlying storage while keeping the Parameter object.
        parameter.data = parameter.to(dtype)
        trainable[name] = parameter
    return trainable


def separate_lora_AB(parameters, B_patterns=None):
    """Split a ``{name: value}`` dict into ``(normal, B)`` by key substring.

    A key goes into the second dict when it contains any of ``B_patterns``
    (default: ``'.lora_B.'`` or ``'__zero__'``); every other key goes into
    the first dict.
    """
    patterns = ['.lora_B.', '__zero__'] if B_patterns is None else B_patterns

    parameters_normal, parameters_B = {}, {}
    for name, value in parameters.items():
        is_b = any(pattern in name for pattern in patterns)
        bucket = parameters_B if is_b else parameters_normal
        bucket[name] = value

    return parameters_normal, parameters_B


def set_attr_recursive(obj, attr, value):
    """Set a nested attribute addressed by a dotted path.

    ``attr`` such as ``"a.b.c"`` walks ``obj.a.b`` with ``getattr`` and then
    assigns ``value`` to attribute ``c`` of the object found there.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, value)
    return


def print_tensor_list_size(tensors):
    """Print the count, total byte size and total element count of tensors.

    ``tensors`` may be a dict (its values are used) or any iterable of
    tensors, including one-shot iterables such as a generator.
    """
    if isinstance(tensors, dict):
        tensors = tensors.values()

    # Materialize once: generators have no len() and can only be walked once.
    tensors = list(tensors)

    total_size = 0
    total_elements = 0
    for tensor in tensors:
        count = tensor.nelement()
        total_size += count * tensor.element_size()
        total_elements += count

    total_size_MB = total_size / (1024 ** 2)
    total_elements_B = total_elements / 1e9

    print(f"Total number of tensors: {len(tensors)}")
    print(f"Total size of tensors: {total_size_MB:.2f} MB")
    print(f"Total number of parameters: {total_elements_B:.3f} billion")
    return


@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
    """Pick each batch item from ``a`` or ``b`` according to a per-item mask.

    ``b`` defaults to zeros shaped like ``a``.  When ``mask_a`` is not given,
    a boolean mask of length ``a.size(0)`` is drawn with
    ``torch.rand(batch) < probability_a``.  Items where the mask is true come
    from ``a``; the rest come from ``b``.
    """
    batch_size = a.size(0)
    fallback = torch.zeros_like(a) if b is None else b

    if mask_a is None:
        mask_a = torch.rand(batch_size) < probability_a

    # Shape the mask as (batch, 1, 1, ...) so it broadcasts over every item.
    broadcast_shape = (batch_size,) + (1,) * (a.dim() - 1)
    selector = mask_a.to(a.device).reshape(broadcast_shape)
    return torch.where(selector, a, fallback)


@torch.no_grad()
def zero_module(module):
    """Fill every parameter of ``module`` with zeros in place and return the module."""
    for parameter in module.parameters():
        parameter.detach().zero_()
    return module


@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
    """Scale the first ``k`` slices along dim 1 of ``m.weight`` by ``alpha``.

    The weight data is replaced with a fresh contiguous tensor; ``m`` itself
    is returned.  Asserts that dim 1 of the weight has at least ``k`` entries.
    """
    scaled = m.weight.data.clone()

    assert int(scaled.shape[1]) >= k

    head = scaled[:, :k]
    scaled[:, :k] = head * alpha
    m.weight.data = scaled.contiguous().clone()
    return m


def freeze_module(m):
    """Disable gradients for ``m`` and run its forward under ``torch.no_grad``.

    The forward callable found on the first call is cached on the module as
    ``_forward_inside_frozen_module``.  Every call wraps that cached callable,
    so freezing the same module repeatedly does not stack ``no_grad``
    wrappers on top of each other.  Returns ``m``.
    """
    if not hasattr(m, '_forward_inside_frozen_module'):
        m._forward_inside_frozen_module = m.forward
    m.requires_grad_(False)
    # Wrap the cached original rather than the current (possibly already
    # wrapped) forward to keep the call chain one level deep.
    m.forward = torch.no_grad()(m._forward_inside_frozen_module)
    return m


def get_latest_safetensors(folder_path):
    """Return the absolute real path of the newest ``*.safetensors`` file.

    "Newest" is decided by modification time.  Only files directly inside
    ``folder_path`` are considered.

    Raises:
        ValueError: if the folder contains no ``.safetensors`` file.
    """
    pattern = os.path.join(folder_path, '*.safetensors')
    candidates = glob.glob(pattern)

    if len(candidates) == 0:
        raise ValueError('No file to resume!')

    newest = max(candidates, key=os.path.getmtime)
    return os.path.abspath(os.path.realpath(newest))


def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Build a prompt from a random subset of comma-separated tags.

    ``tags_str`` is split on ``', '``.  A target count is drawn uniformly from
    ``[min_length, max_length]`` and capped at the number of tags; that many
    distinct tags are sampled (in random order) and re-joined with ``', '``.
    """
    all_tags = tags_str.split(', ')
    wanted = random.randint(min_length, max_length)
    count = min(wanted, len(all_tags))
    chosen = random.sample(all_tags, k=count)
    return ', '.join(chosen)


def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
    """Return ``n`` values going from ``a`` to ``b`` as a Python list.

    The interpolation fraction is ``linspace(0, 1, n) ** gamma``, so
    ``gamma=1`` is linear.  With ``round_to_int`` the values are rounded with
    ``np.round`` and converted to integers.
    """
    fractions = np.linspace(0, 1, n) ** gamma
    values = a + (b - a) * fractions
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()


def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
    """Draw one uniform random value from each of ``n`` equal sub-intervals.

    The range from ``inclusive`` to ``exclusive`` is divided into ``n``
    consecutive equal-width bins and a single sample is taken from every bin
    using NumPy's global random state.  Returns a Python list ordered by bin;
    with ``round_to_int`` the samples are rounded and converted to integers.
    """
    edges = np.linspace(0, 1, n + 1)
    lower, upper = edges[:-1], edges[1:]
    points = np.random.uniform(lower, upper)

    values = inclusive + (exclusive - inclusive) * points
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()


def soft_append_bcthw(history, current, overlap=0):
    """Concatenate two tensors along dim 2 with an optional linear cross-fade.

    With ``overlap <= 0`` the tensors are simply concatenated.  Otherwise the
    last ``overlap`` slices of ``history`` are blended with the first
    ``overlap`` slices of ``current`` using weights that fall linearly from
    1 to 0 for ``history`` (and rise from 0 to 1 for ``current``).  The
    result has ``history.shape[2] + current.shape[2] - overlap`` slices along
    dim 2 and is converted to the dtype/device of ``history``.
    """
    if overlap <= 0:
        return torch.cat([history, current], dim=2)

    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"

    fade_out = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
    history_tail = history[:, :, -overlap:]
    current_head = current[:, :, :overlap]
    blended = fade_out * history_tail + (1 - fade_out) * current_head

    pieces = [history[:, :, :-overlap], blended, current[:, :, overlap:]]
    return torch.cat(pieces, dim=2).to(history)


def save_bcthw_as_mp4(x, output_filename, fps=10):
    """Write a 5-D ``(b, c, t, h, w)`` tensor to a video file as a grid.

    Values are clamped to ``[-1, 1]``, mapped to ``[0, 255]`` and cast to
    uint8.  Batch items are tiled into a grid with ``per_row`` columns, where
    ``per_row`` is the first of 6, 5, 4, 3, 2 that divides the batch size
    (or the batch size itself if none does).  The parent directory is created
    if needed.  Returns the uint8 frame tensor laid out as
    ``(t, rows*h, cols*w, c)`` that was handed to the video writer.
    """
    b, _, _, _, _ = x.shape
    per_row = next((p for p in (6, 5, 4, 3, 2) if b % p == 0), b)

    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    frames = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    frames = frames.detach().cpu().to(torch.uint8)
    frames = einops.rearrange(frames, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, frames, fps=fps, video_codec='libx264', options={'crf': '0'})
    return frames


def save_bcthw_as_png(x, output_filename):
    """Write a ``(b, c, t, h, w)`` tensor as one PNG contact sheet.

    Values are clamped to ``[-1, 1]``, mapped to ``[0, 255]`` and cast to
    uint8.  Batch items are stacked vertically and the ``t`` axis is laid out
    horizontally.  The parent directory is created if needed.  Returns
    ``output_filename``.
    """
    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    pixels = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    pixels = pixels.detach().cpu().to(torch.uint8)
    sheet = einops.rearrange(pixels, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(sheet, output_filename)
    return output_filename


def save_bchw_as_png(x, output_filename):
    """Write a ``(b, c, h, w)`` tensor as one PNG with items side by side.

    Values are clamped to ``[-1, 1]``, mapped to ``[0, 255]`` and cast to
    uint8.  Batch items are concatenated along the width.  The parent
    directory is created if needed.  Returns ``output_filename``.
    """
    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    pixels = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    pixels = pixels.detach().cpu().to(torch.uint8)
    strip = einops.rearrange(pixels, 'b c h w -> c h (b w)')
    torchvision.io.write_png(strip, output_filename)
    return output_filename


def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors, zero-padding each at the end of every dim as needed.

    When the shapes match this is a plain ``tensor1 + tensor2``.  Otherwise
    the result has, per dimension, the larger of the two sizes; each input is
    placed at the origin of a zero tensor of that shape and the two are
    summed.  The padded result uses the promoted dtype of the two inputs and
    the device of ``tensor1``, matching what the same-shape path returns
    (instead of always falling back to float32 on the CPU).

    NOTE: dimensions are paired with ``zip``, so the inputs are expected to
    have the same number of dimensions.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape
    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    dtype = torch.result_type(tensor1, tensor2)
    device = tensor1.device
    result = torch.zeros(new_shape, dtype=dtype, device=device)

    for tensor, shape in ((tensor1, shape1), (tensor2, shape2)):
        region = tuple(slice(0, s) for s in shape)
        result[region] += tensor.to(device=device, dtype=dtype)

    return result


def print_free_mem():
    """Empty the CUDA cache, then print free and total memory of device 0 in MB."""
    torch.cuda.empty_cache()
    bytes_per_mb = 1024 ** 2
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    print(f"Free memory: {free_bytes / bytes_per_mb:.2f} MB")
    print(f"Total memory: {total_bytes / bytes_per_mb:.2f} MB")
    return


def print_gpu_parameters(device, state_dict, log_count=1):
    """Print a short summary dict of a state dict for debugging.

    The summary holds the given ``device`` value, the number of keys, and,
    for the first ``log_count`` entries, the first three flattened values of
    each tensor.
    """
    logged_params = {}
    for index, (key, tensor) in enumerate(state_dict.items()):
        if index < log_count:
            logged_params[key] = tensor.flatten()[:3].tolist()
        else:
            break

    summary = {
        "device": device,
        "keys_count": len(state_dict),
        "params": logged_params,
    }
    print(str(summary))
    return


def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
    """Render ``text`` as black, word-wrapped lines on a white RGB image.

    Words are packed greedily into lines no wider than ``width`` (a single
    word wider than the image still gets its own line).  Lines are drawn from
    the top; drawing stops once the next line would not fit in ``height``.
    Empty or whitespace-only text yields a blank white image.

    Returns the image as a NumPy array.
    """
    from PIL import Image, ImageDraw, ImageFont

    txt = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(txt)
    font = ImageFont.truetype(font_path, size=size)

    # str.split() with no arguments drops all whitespace, so both '' and
    # whitespace-only input produce an empty list here; without this guard
    # words[0] below would raise IndexError for whitespace-only text.
    words = text.split()
    if not words:
        return np.array(txt)

    # Split text into lines that fit within the image width
    lines = []
    current_line = words[0]

    for word in words[1:]:
        line_with_word = f"{current_line} {word}"
        if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
            current_line = line_with_word
        else:
            lines.append(current_line)
            current_line = word

    lines.append(current_line)

    # Draw the text line by line
    y = 0
    line_height = draw.textbbox((0, 0), "A", font=font)[3]

    for line in lines:
        if y + line_height > height:
            break  # stop drawing if the next line will be outside the image
        draw.text((0, y), line, fill="black", font=font)
        y += line_height

    return np.array(txt)


def blue_mark(x):
    """Return a copy of ``x`` with local contrast of channel 2 amplified.

    Channel 2 is replaced by ``blur + 16 * (channel - blur)`` using a 9x9 box
    blur, then clipped to ``[-1, 1]``.  The input array is not modified.
    """
    marked = x.copy()
    channel = marked[:, :, 2]
    blurred = cv2.blur(channel, (9, 9))
    marked[:, :, 2] = ((channel - blurred) * 16.0 + blurred).clip(-1, 1)
    return marked


def green_mark(x):
    """Return a copy of ``x`` with channels 0 and 2 set to -1 (channel 1 kept)."""
    marked = x.copy()
    for channel in (2, 0):
        marked[:, :, channel] = -1
    return marked


def frame_mark(x):
    """Return a copy of ``x`` with a border painted on it.

    The first and last 64 rows are set to -1, after which the first and last
    8 columns are set to 1 (so the columns win where they intersect the rows).
    """
    framed = x.copy()
    for rows in (slice(None, 64), slice(-64, None)):
        framed[rows] = -1
    for cols in (slice(None, 8), slice(-8, None)):
        framed[:, cols] = 1
    return framed


@torch.inference_mode()
def pytorch2numpy(imgs):
    """Convert an iterable of channel-first tensors to a list of uint8 arrays.

    Each item has its first dimension moved to the end, is mapped with
    ``v * 127.5 + 127.5``, clipped to ``[0, 255]`` and cast to ``uint8``.
    """
    def _to_uint8(item):
        channels_last = item.movedim(0, -1) * 127.5 + 127.5
        return channels_last.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)

    return [_to_uint8(item) for item in imgs]


@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack same-shaped arrays into one float tensor with the last axis moved to dim 1.

    Values are mapped with ``v / 127.5 - 1.0``, so 0 becomes -1 and 255
    becomes 1.
    """
    stacked = np.stack(imgs, axis=0)
    scaled = torch.from_numpy(stacked).float() / 127.5 - 1.0
    return scaled.movedim(-1, 1)


@torch.no_grad()
def duplicate_prefix_to_suffix(x, count, zero_out=False):
    """Append the first ``count`` items of ``x`` (along dim 0) to its end.

    With ``zero_out`` the appended items are zeros of the same shape instead
    of copies.
    """
    prefix = x[:count]
    suffix = torch.zeros_like(prefix) if zero_out else prefix
    return torch.cat([x, suffix], dim=0)


def weighted_mse(a, b, weight):
    """Return the mean of ``weight * (a - b) ** 2``, computed in float32."""
    squared_error = (a.float() - b.float()) ** 2
    weighted = weight.float() * squared_error
    return torch.mean(weighted)


def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
    """Map ``x`` from ``[x_min, x_max]`` onto ``[y_min, y_max]`` with clamping.

    The normalised position of ``x`` is clamped to ``[0, 1]`` and raised to
    the power ``sigma`` before being used to interpolate between ``y_min``
    and ``y_max`` (``sigma=1`` gives plain linear interpolation).
    """
    fraction = (x - x_min) / (x_max - x_min)
    fraction = max(0.0, min(fraction, 1.0))
    shaped = fraction ** sigma
    return y_min + shaped * (y_max - y_min)


def expand_to_dims(x, target_dims):
    """View ``x`` with trailing size-1 dims appended until it has ``target_dims`` dims.

    Tensors that already have at least ``target_dims`` dims keep their shape.
    """
    missing = max(0, target_dims - x.dim())
    trailing_ones = [1] * missing
    return x.view(*x.shape, *trailing_ones)


def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Tile ``tensor`` along dim 0 until its first dim equals ``batch_size``.

    ``None`` is passed through.  A tensor whose first dim already equals
    ``batch_size`` is returned as is.

    Raises:
        ValueError: if ``batch_size`` is not a multiple of the first dim.
    """
    if tensor is None:
        return None

    first_dim = tensor.shape[0]
    if first_dim == batch_size:
        return tensor

    repeat_times, remainder = divmod(batch_size, first_dim)
    if remainder != 0:
        raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")

    trailing_ones = [1] * (tensor.dim() - 1)
    return tensor.repeat(repeat_times, *trailing_ones)


def dim5(x):
    """Shorthand for ``expand_to_dims(x, 5)``."""
    return expand_to_dims(x, target_dims=5)


def dim4(x):
    """Shorthand for ``expand_to_dims(x, 4)``."""
    return expand_to_dims(x, target_dims=4)


def dim3(x):
    """Shorthand for ``expand_to_dims(x, 3)``."""
    return expand_to_dims(x, target_dims=3)


def crop_or_pad_yield_mask(x, length):
    """Force dim 1 of a 3-D tensor to ``length`` and return a validity mask.

    Shorter inputs are zero-padded at the end; longer inputs are truncated.
    The boolean mask of shape ``(B, length)`` is true where the output holds
    data copied from ``x`` and false on padding.

    Returns:
        ``(y, mask)``.
    """
    B, F, C = x.shape

    if F >= length:
        full_mask = torch.ones((B, length), dtype=torch.bool, device=x.device)
        return x[:, :length, :], full_mask

    padded = torch.zeros((B, length, C), dtype=x.dtype, device=x.device)
    padded[:, :F, :] = x

    mask = torch.zeros((B, length), dtype=torch.bool, device=x.device)
    mask[:, :F] = True
    return padded, mask


def extend_dim(x, dim, minimal_length, zero_pad=False):
    """Pad ``x`` at the end of dimension ``dim`` up to ``minimal_length``.

    If the dimension is already long enough ``x`` is returned unchanged.
    Otherwise the missing entries are zeros (``zero_pad=True``) or copies of
    the last slice along ``dim``.  ``dim`` may be negative in both modes.
    """
    original_length = int(x.shape[dim])

    if original_length >= minimal_length:
        return x

    missing = minimal_length - original_length

    if zero_pad:
        padding_shape = list(x.shape)
        padding_shape[dim] = missing
        padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
    else:
        # Normalize a negative dim first: the index tuple below is built by
        # repeating slices ``axis`` times, which only works for axis >= 0.
        axis = dim if dim >= 0 else dim + x.dim()
        idx = (slice(None),) * axis + (slice(-1, None),)
        last_element = x[idx]
        padding = last_element.repeat_interleave(missing, dim=axis)

    return torch.cat([x, padding], dim=dim)


def lazy_positional_encoding(t, repeats=None):
    """Return 256-dim timestep embeddings for one or more timesteps.

    ``t`` may be a single value or a list of values.  The embeddings come
    from diffusers' ``get_timestep_embedding`` and have shape ``(N, 256)``.
    When ``repeats`` is given, a middle dimension is inserted and expanded,
    giving shape ``(N, repeats, 256)``.
    """
    timesteps = t if isinstance(t, list) else [t]

    from diffusers.models.embeddings import get_timestep_embedding

    embedding = get_timestep_embedding(
        timesteps=torch.tensor(timesteps),
        embedding_dim=256,
        flip_sin_to_cos=True,
        downscale_freq_shift=0.0,
        scale=1.0,
    )

    if repeats is None:
        return embedding

    return embedding[:, None, :].expand(-1, repeats, -1)


def state_dict_offset_merge(A, B, C=None):
    """Compute ``A + B`` (or ``A + B - C``) key by key.

    The keys of ``A`` drive the merge; ``B`` and ``C`` values are converted
    to the dtype/device of the matching ``A`` value first.  Returns a new
    dict; the inputs are not modified.
    """
    merged = {}

    for key in A.keys():
        base = A[key]
        total = base + B[key].to(base)
        if C is not None:
            total = total - C[key].to(base)
        merged[key] = total

    return merged


def state_dict_weighted_merge(state_dicts, weights):
    """Return the weighted average of several state dicts.

    Weights are normalised to sum to one.  The keys of the first state dict
    drive the merge, and values from the other dicts are converted to the
    dtype/device of the running result before being added.

    Raises:
        ValueError: if the two lists differ in length, or the weights sum to zero.
    """
    if len(state_dicts) != len(weights):
        raise ValueError("Number of state dictionaries must match number of weights")

    if not state_dicts:
        return {}

    total_weight = sum(weights)

    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")

    normalized_weights = [w / total_weight for w in weights]
    first_weight, other_weights = normalized_weights[0], normalized_weights[1:]
    first_dict, other_dicts = state_dicts[0], state_dicts[1:]

    result = {}
    for key in first_dict.keys():
        # The multiplication creates a new tensor, so the in-place adds below
        # never touch the caller's tensors.
        accumulated = first_dict[key] * first_weight
        for other, weight in zip(other_dicts, other_weights):
            accumulated += other[key].to(accumulated) * weight
        result[key] = accumulated

    return result


def group_files_by_folder(all_files):
    """Group paths by the name of their immediate parent folder.

    Only the folder's base name is used as the key, so files in different
    locations whose parent folders share a name end up in the same group.
    Returns a list of lists, ordered by first appearance of each folder name.
    """
    groups = {}
    for path in all_files:
        parent_name = os.path.basename(os.path.dirname(path))
        groups.setdefault(parent_name, []).append(path)
    return list(groups.values())


def generate_timestamp():
    """Return a string of the form ``yymmdd_HHMMSS_mmm_R`` for the current local time.

    ``mmm`` is the zero-padded millisecond part and ``R`` is a random integer
    in ``[0, 9999]`` (not zero-padded).
    """
    now = datetime.datetime.now()
    date_part = now.strftime('%y%m%d_%H%M%S')
    millis_part = f"{now.microsecond // 1000:03d}"
    suffix = random.randint(0, 9999)
    return f"{date_part}_{millis_part}_{suffix}"


def write_PIL_image_with_png_info(image, metadata, path):
    """Save ``image`` as PNG at ``path`` with ``metadata`` as text chunks.

    Every ``key: value`` pair of ``metadata`` is added as a PNG text entry.
    Returns the ``image`` object that was passed in.
    """
    from PIL.PngImagePlugin import PngInfo

    info = PngInfo()
    for name in metadata:
        info.add_text(name, metadata[name])

    image.save(path, "PNG", pnginfo=info)
    return image


def torch_safe_save(content, path):
    """Save ``content`` with ``torch.save`` via a temporary file.

    The data is written to ``path + '_tmp'`` and then moved onto ``path``
    with ``os.replace``, so ``path`` is only replaced after the save call has
    returned.  Returns ``path``.
    """
    staging_path = path + '_tmp'
    torch.save(content, staging_path)
    os.replace(staging_path, path)
    return path


def move_optimizer_to_device(optimizer, device):
    """Move every tensor stored in ``optimizer.state`` with ``Tensor.to(device)``.

    Non-tensor state entries are left untouched.  The state dicts are updated
    in place.
    """
    for param_state in optimizer.state.values():
        for name, value in param_state.items():
            if not isinstance(value, torch.Tensor):
                continue
            param_state[name] = value.to(device)