# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import itertools
import logging
import os
import mmap
from typing import Any, List, Optional

import numpy as np

import torch
# NOTE(review): this runs at import time and changes torch's print options for
# the whole process (tensors are printed in full instead of being summarized).
# It looks like a leftover debugging aid -- consider removing it.
torch.set_printoptions(profile="full")
import torch.nn.functional as F
from fairseq.data import data_utils, Dictionary
from fairseq.data.fairseq_dataset import FairseqDataset

logger = logging.getLogger(__name__)


def load_audio(manifest_path, max_keep, min_keep):
    """Parse an audio manifest and filter its entries by size.

    The first manifest line is the audio root directory. Every following line
    is tab-separated, with the relative audio path in column 0 and an integer
    size in column 1.

    Args:
        manifest_path: path to the manifest file.
        max_keep: drop entries whose size is greater than this (None = no limit).
        min_keep: drop entries whose size is less than this (None = no limit).

    Returns:
        ``(root, names, inds, tot, sizes)``. These are the root directory, the
        kept relative paths, their 0-based entry indices (root line excluded),
        the total number of entries in the manifest, and the kept sizes.

    Raises:
        AssertionError: if an entry line has fewer than two columns.
    """
    n_long, n_short = 0, 0
    names, inds, sizes = [], [], []
    # Stays -1 when the manifest has no entries after the root line, so
    # ``tot`` below is 0 instead of the loop variable being unbound.
    ind = -1
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) >= 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                inds.append(ind)
                sizes.append(sz)
    tot = ind + 1
    # max()/min() raise on an empty list, e.g. when every entry was filtered out.
    if sizes:
        size_range = f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
    else:
        size_range = "longest-loaded=n/a, shortest-loaded=n/a"
    logger.info(
        f"max_keep={max_keep}, min_keep={min_keep}, "
        f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
        f"{size_range}"
    )
    return root, names, inds, tot, sizes


def load_label(label_path, inds, tot):
    """Load the selected lines of a UTF-8 label file into memory.

    Args:
        label_path: path to a text file with one label per line.
        inds: line indices to keep (as returned by ``load_audio``).
        tot: expected total number of lines in the file.

    Returns:
        The labels at ``inds``, in that order, with trailing whitespace
        (including the newline) stripped.

    Raises:
        AssertionError: if the file does not contain exactly ``tot`` lines.
    """
    # Decode explicitly as UTF-8 instead of the locale default. The labels are
    # Arabic text, and the lazy loading path decodes its bytes as UTF-8 too.
    with open(label_path, encoding="utf-8") as f:
        labels = [line.rstrip() for line in f]
    assert (
        len(labels) == tot
    ), f"number of labels does not match ({len(labels)} != {tot})"
    return [labels[i] for i in inds]


def load_label_offset(label_path, inds, tot):
    """Compute the byte range of each selected line in a label file.

    Args:
        label_path: path to a text file with one label per line.
        inds: line indices to keep (as returned by ``load_audio``).
        tot: expected total number of lines in the file.

    Returns:
        A list of ``(start, end)`` byte offsets, one per index in ``inds``.
        ``end`` is exclusive, and the range includes the line terminator.

    Raises:
        AssertionError: if the file does not contain exactly ``tot`` lines.
    """
    # Read in binary mode so each line's length is its exact on-disk byte count.
    # Text mode translates "\r\n" to "\n", so re-encoding the decoded line
    # under-counts CRLF lines by one byte and shifts every later offset. It
    # would also depend on the locale's default encoding.
    with open(label_path, "rb") as f:
        code_lengths = [len(line) for line in f]
    assert (
        len(code_lengths) == tot
    ), f"number of labels does not match ({len(code_lengths)} != {tot})"
    offsets = list(itertools.accumulate([0] + code_lengths))
    return [(offsets[i], offsets[i + 1]) for i in inds]


class SpeechToTextDataset(FairseqDataset):
    """Paired audio/label dataset for the ArTST speech-to-text ("s2t") task.

    Audio entries come from a manifest parsed by ``load_audio``. Each file in
    ``label_paths`` must hold one label line per manifest entry. Labels are
    either held in memory (``store_labels=True``) or read lazily from disk,
    using byte offsets computed by ``load_label_offset``.

    Args:
        manifest_path: audio manifest (its first line is the audio root dir).
        sample_rate: expected sample rate; ``postprocess`` raises on any other.
        label_paths: one label file per label stream.
        label_processors: optional per-stream callables applied to each label.
        max_keep_sample_size: upper size filter passed to ``load_audio``.
        min_keep_sample_size: lower size filter passed to ``load_audio``.
        shuffle: when True, ``ordered_indices`` orders equal-size entries
            randomly.
        normalize: apply ``F.layer_norm`` over each waveform.
        store_labels: load all labels into memory instead of reading lazily.
        tgt_dict: target dictionary. Its ``pad()``/``eos()`` indices are used
            when collating, so ``collater`` needs it to be set.
        tokenizer: optional object whose ``encode`` is applied to each raw
            label before ``label_processors``.
    """

    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        shuffle: bool = True,
        normalize: bool = False,
        store_labels: bool = True,
        tgt_dict: Optional[Dictionary] = None,
        tokenizer=None,
    ):
        self.audio_root, self.audio_names, inds, tot, self.wav_sizes = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )

        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.tgt_dict = tgt_dict
        self.tokenizer = tokenizer

        self.num_labels = len(label_paths)
        self.label_processors = label_processors
        self.store_labels = store_labels

        if store_labels:
            self.label_list = [load_label(p, inds, tot) for p in label_paths]
            # The full list holds one transcript per utterance and can be huge.
            # Log only a summary at INFO; the contents go to DEBUG, formatted
            # lazily so nothing is built unless DEBUG is enabled.
            logger.info(
                f"loaded {len(inds)} labels from each of "
                f"{self.num_labels} label file(s)"
            )
            logger.debug("label_list: %s", self.label_list)
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot) for p in label_paths
            ]
        assert label_processors is None or len(label_processors) == self.num_labels

        self.normalize = normalize
        logger.info(f"normalize={normalize}")

    def get_audio(self, index):
        """Load waveform ``index`` as a float32 tensor, then ``postprocess`` it."""
        # Imported lazily, presumably so soundfile is only needed when audio
        # is actually read.
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        wav, cur_sample_rate = sf.read(wav_path)
        wav = torch.from_numpy(wav).float()
        return self.postprocess(wav, cur_sample_rate)

    def get_label(self, index, label_idx):
        """Return label ``index`` of stream ``label_idx``.

        Trailing whitespace (including the newline) is stripped in both storage
        modes. Then ``tokenizer.encode`` and the stream's label processor are
        applied when configured.
        """
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            offset_s, offset_e = self.label_offsets_list[label_idx][index]
            # The offsets are byte positions, so read raw bytes and decode them
            # afterwards. A text-mode read(n) counts characters, which drifts on
            # multi-byte UTF-8 (e.g. Arabic) text.
            with open(self.label_paths[label_idx], "rb") as f:
                f.seek(offset_s)
                label = f.read(offset_e - offset_s).decode("utf-8")
            # load_label() rstrip()s every line; do the same here so that both
            # storage modes yield the same label text (no trailing newline).
            label = label.rstrip()

        if self.tokenizer is not None:
            label = self.tokenizer.encode(label)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        """Return the labels of entry ``index`` for every label stream."""
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        wav = self.get_audio(index)
        labels = self.get_labels(index)
        return {"id": index, "source": wav, "label_list": labels}

    def __len__(self):
        return len(self.wav_sizes)

    def collater(self, samples):
        """Merge samples from ``__getitem__`` into a padded s2t batch.

        Only the first label stream is used. EOS is appended to each label.
        ``target`` is right-padded with ``tgt_dict.pad()``, and
        ``prev_output_tokens`` holds the same tokens with the EOS moved to the
        front. ``ntokens`` counts label tokens without the appended EOS.
        Returns ``{}`` if no sample has audio.
        """
        samples = [s for s in samples if s["source"] is not None]
        if not samples:
            return {}

        audios = [s["source"] for s in samples]
        audio_size = max(len(a) for a in audios)
        collated_audios, padding_mask = self.collater_audio(audios, audio_size)

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)

        # Strip the padding from each label and append EOS.
        targets, lengths = targets_list[0], lengths_list[0]
        eos = torch.tensor([self.tgt_dict.eos()])
        decoder_label = [
            torch.cat((targets[i, : lengths[i]], eos), 0).long()
            for i in range(targets.size(0))
        ]

        decoder_target = data_utils.collate_tokens(
            decoder_label,
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        )
        decoder_target_lengths = torch.tensor(
            [x.size(0) for x in decoder_label], dtype=torch.long
        )
        prev_output_tokens = data_utils.collate_tokens(
            decoder_label,
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        )

        net_input = {
            "source": collated_audios,
            "padding_mask": padding_mask,
            "prev_output_tokens": prev_output_tokens,
            "task_name": "s2t",
        }
        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
            "target": decoder_target,
            "target_lengths": decoder_target_lengths,
            "task_name": "s2t",
            "ntokens": ntokens_list[0],
        }

    def collater_audio(self, audios, audio_size):
        """Right-pad 1-D waveforms with zeros to ``audio_size``.

        Returns ``(collated_audios, padding_mask)``. ``padding_mask`` is True
        on padded positions. Raises if a waveform is longer than ``audio_size``.
        """
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = torch.zeros(collated_audios.shape, dtype=torch.bool)
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff > 0:
                raise Exception("Diff should not be larger than 0")
            # collated_audios starts zero-filled, so copying the waveform into
            # the row's prefix is equivalent to concatenating zero padding.
            collated_audios[i, : len(audio)] = audio
            if diff < 0:
                padding_mask[i, diff:] = True
        return collated_audios, padding_mask

    def collater_seq_label(self, targets, pad):
        """Right-pad ``targets`` with ``pad``; return (padded, lengths, ntokens)."""
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label):
        """Pad the label sequences of the first label stream.

        ``zip`` against a one-element pad list truncates to the first stream,
        so each returned list holds a single entry.
        """
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, [self.tgt_dict.pad()])
        for targets, pad in itr:
            targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        return self.wav_sizes[index]

    @property
    def sizes(self):
        return np.array(self.wav_sizes)

    def ordered_indices(self):
        """Return indices sorted by descending ``wav_sizes``.

        Ties are broken by a random permutation when ``shuffle`` is set, and
        otherwise by descending index.
        """
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        # np.lexsort treats the last key as the primary key, so sizes dominate.
        order.append(self.wav_sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        """Downmix to mono, check the sample rate and optionally normalize.

        A 2-D waveform is averaged over its last axis; soundfile returns
        multi-channel audio as ``(frames, channels)``. Raises if
        ``cur_sample_rate`` differs from ``self.sample_rate``. When
        ``normalize`` is set, applies ``F.layer_norm`` over the whole waveform.
        """
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav