# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# This code is modified from 
# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/input_strategies.py
import random
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Type

from lhotse import CutSet
from lhotse.dataset.collation import collate_features
from lhotse.dataset.input_strategies import (
    ExecutorType,
    PrecomputedFeatures,
    _get_executor,
)
from lhotse.utils import fastcopy


class PromptedFeatures:
    """Bundle of a prompt object and the feature object it belongs to.

    Both members are stored as given. The helper methods forward to them so
    that the pair can be handled like a single tensor-like value: ``to``
    moves both members, while ``sum`` and ``ndim`` are answered by
    ``features`` alone.
    """

    def __init__(self, prompts, features):
        """Store ``prompts`` and ``features`` without copying or checking."""
        self.features = features
        self.prompts = prompts

    def to(self, device):
        """Return a new pair whose members are the results of ``.to(device)``."""
        moved_prompts = self.prompts.to(device)
        moved_features = self.features.to(device)
        return PromptedFeatures(moved_prompts, moved_features)

    def sum(self):
        """Return ``features.sum()``; the prompts do not contribute."""
        total = self.features.sum()
        return total

    @property
    def ndim(self):
        """``ndim`` of ``features`` (the prompts are not consulted)."""
        feats = self.features
        return feats.ndim

    @property
    def data(self):
        """The pair as a ``(prompts, features)`` tuple."""
        return self.prompts, self.features


class PromptedPrecomputedFeatures(PrecomputedFeatures):
    """Precomputed-feature input strategy that also collates one prompt per cut.

    At construction time every cut id in ``cuts`` is mapped to a list of
    "neighbor" cuts:

    * ``"libritts"``: the utterances adjacent to it in the sorted id order of
      the same speaker (speaker taken from the cut's first supervision);
    * ``"ljspeech"``: the utterances adjacent to it in the iteration order of
      ``cuts`` whose id shares the same first five characters.

    When the strategy is called on a batch, one neighbor is drawn at random
    for each cut, truncated at a random offset, and collated next to the
    batch's own features.
    """

    def __init__(
        self,
        dataset: str,
        cuts: CutSet,
        num_workers: int = 0,
        executor_type: Type[ExecutorType] = ThreadPoolExecutor,
    ) -> None:
        """Build the neighbor table for ``cuts``.

        Args:
            dataset: Dataset name, matched case-insensitively against
                ``"libritts"`` and ``"ljspeech"``.
            cuts: Cuts from which prompts may later be drawn.
            num_workers: Forwarded to ``PrecomputedFeatures.__init__``.
            executor_type: Forwarded to ``PrecomputedFeatures.__init__``.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        super().__init__(num_workers, executor_type)
        self.utt2neighbors = self._create_utt2neighbors(dataset, cuts)

    def __call__(
        self, cuts: CutSet
    ) -> Tuple[PromptedFeatures, PromptedFeatures]:
        """Collate features and randomly chosen prompts for ``cuts``.

        Returns:
            Two ``PromptedFeatures``: the first pairs the collated prompts
            with the collated features, the second pairs the corresponding
            length outputs of ``collate_features``.
        """
        features, features_lens = self._collate_features(cuts)
        prompts, prompts_lens = self._collate_prompts(cuts)
        return (
            PromptedFeatures(prompts, features),
            PromptedFeatures(prompts_lens, features_lens),
        )

    def _create_utt2neighbors(self, dataset, cuts):
        """Return a ``defaultdict(list)`` mapping cut id -> candidate prompt cuts.

        Raises:
            ValueError: If ``dataset`` is neither "libritts" nor "ljspeech"
                (case-insensitive).
        """
        utt2neighbors = defaultdict(list)
        utt2cut = {cut.id: cut for cut in cuts}
        dataset_name = dataset.lower()
        if dataset_name == "libritts":
            self._process_libritts_dataset(utt2neighbors, utt2cut, cuts)
        elif dataset_name == "ljspeech":
            self._process_ljspeech_dataset(utt2neighbors, utt2cut, cuts)
        else:
            raise ValueError("Unsupported dataset")
        return utt2neighbors

    def _process_libritts_dataset(self, utt2neighbors, utt2cut, cuts):
        """Fill ``utt2neighbors`` in place with same-speaker adjacent utterances.

        Utterance ids are grouped by ``cut.supervisions[0].speaker`` and
        sorted within each group. A speaker with a single utterance gets that
        utterance as its own neighbor.
        """
        speaker2utts = defaultdict(list)
        for cut in cuts:
            speaker2utts[cut.supervisions[0].speaker].append(cut.id)

        for uttids in speaker2utts.values():
            sorted_uttids = sorted(uttids)
            if len(sorted_uttids) == 1:
                only_utt = sorted_uttids[0]
                utt2neighbors[only_utt].append(utt2cut[only_utt])
                continue

            # The first utterance has no predecessor, so its "previous" entry
            # is the second utterance; every id therefore has a previous one.
            utt2prevutt = dict(
                zip(sorted_uttids, [sorted_uttids[1]] + sorted_uttids[:-1])
            )
            # The last utterance has no successor and is absent from this map.
            utt2postutt = dict(zip(sorted_uttids[:-1], sorted_uttids[1:]))
            for utt in sorted_uttids:
                utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
                if utt in utt2postutt:
                    utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])

    def _process_ljspeech_dataset(self, utt2neighbors, utt2cut, cuts):
        """Fill ``utt2neighbors`` in place with adjacent same-prefix utterances.

        Adjacency follows the iteration order of ``cuts``. A neighbor is
        preferred when its id shares the first five characters with the
        utterance's id. If neither adjacent utterance qualifies, the adjacent
        utterance is used anyway so that every id has at least one candidate.
        An empty ``cuts`` leaves ``utt2neighbors`` untouched.
        """
        uttids = [cut.id for cut in cuts]
        if not uttids:
            # Nothing to index; indexing uttids[1] below would fail.
            return
        if len(uttids) == 1:
            utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
            return

        # The first utterance has no predecessor, so its "previous" entry is
        # the second utterance; the last one has no "post" entry.
        utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
        utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
        for utt in uttids:
            prevutt, postutt = utt2prevutt.get(utt), utt2postutt.get(utt)
            neighbors = utt2neighbors[utt]
            if prevutt and utt[:5] == prevutt[:5]:
                neighbors.append(utt2cut[prevutt])
            if postutt and utt[:5] == postutt[:5]:
                neighbors.append(utt2cut[postutt])
            if not neighbors and prevutt:
                # No same-prefix neighbor exists. Without a fallback the list
                # stays empty and random.choice() in _collate_prompts raises
                # IndexError for this utterance.
                neighbors.append(utt2cut[prevutt])

    def _collate_features(self, cuts):
        """Run ``collate_features`` on ``cuts`` with this strategy's executor."""
        executor = _get_executor(
            self.num_workers, executor_type=self._executor_type
        )
        return collate_features(cuts, executor=executor)

    def _collate_prompts(self, cuts):
        """Pick one random neighbor per cut, truncate them, and collate them.

        Each chosen neighbor is copied with the id ``"<cut id>-<batch index>"``.
        All prompts are truncated (random offset) to the shortest chosen
        prompt's duration, capped at 3.0.

        Raises:
            IndexError: From ``random.choice`` when a cut id has no neighbors,
                e.g. an id that was not in the cuts given at construction
                (``utt2neighbors`` is a defaultdict and yields ``[]``).
        """
        prompts_cuts = []
        for k, cut in enumerate(cuts):
            chosen = random.choice(self.utt2neighbors[cut.id])
            prompts_cuts.append(fastcopy(chosen, id=f"{cut.id}-{str(k)}"))

        mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
        prompts_cuts = CutSet(
            cuts={k: cut for k, cut in enumerate(prompts_cuts)}
        ).truncate(
            max_duration=mini_duration, offset_type="random", preserve_id=False
        )

        return self._collate_features(prompts_cuts)