# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Feature extractor class for Phi4Multimodal
"""

from typing import Optional, Union, List, Tuple

import numpy as np

from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.image_processing_utils import BatchFeature
from transformers.utils import TensorType, is_torch_available, logging


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)

# Audio accepted by the feature extractor: one waveform (NumPy array or torch tensor)
# or a list/tuple of waveforms. The torch members are string forward references, so
# this alias can be built even when torch is not importable (see the guarded import).
AudioInput = Union[
    np.ndarray,
    "torch.Tensor",  # noqa: F821
    List[np.ndarray],
    Tuple[np.ndarray],
    List["torch.Tensor"],  # noqa: F821
    Tuple["torch.Tensor"],  # noqa: F821
]


# TODO: @eustlb, remove this once #36603 is merged.
def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
    """Create a Mel filter-bank the same as SpeechLib FbankFC.

    Triangular filters are spaced uniformly on the natural-log Mel scale
    ``1127 * ln(1 + f / 700)`` and evaluated at every FFT bin in
    ``[f2bin(fmin) + 1, f2bin(fmax) - 1]``; all other bins are zero.

    Args:
        sample_rate (int): Sample rate in Hz. number > 0 [scalar]
        n_fft (int): FFT size. int > 0 [scalar]
        n_mels (int): Number of Mel filters. int >= 0 [scalar]
        fmin (float): lowest frequency (in Hz). If None use 0.0.
            float >= 0 [scalar]
        fmax (float): highest frequency (in Hz). If None use sample_rate / 2.
            fmin < fmax <= sample_rate / 2 [scalar]

    Returns:
        out (numpy.ndarray): float32 Mel transform matrix
            [shape=(n_mels, n_fft // 2 + 1)]

    Raises:
        ValueError: if ``fmin`` is negative or ``fmax`` is not in ``(fmin, sample_rate / 2]``.
    """
    bank_width = int(n_fft // 2 + 1)
    if fmax is None:
        fmax = sample_rate / 2
    if fmin is None:
        fmin = 0
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if fmin < 0:
        raise ValueError("fmin cannot be negative")
    if not fmin < fmax <= sample_rate / 2:
        raise ValueError("fmax must be between (fmin, samplerate / 2]")

    def mel(f):
        return 1127.0 * np.log(1.0 + f / 700.0)

    def bin2mel(fft_bin):
        return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))

    def f2bin(f):
        return int((f * n_fft / sample_rate) + 0.5)

    # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
    klo = f2bin(fmin) + 1
    # Never let the range become negative; an empty range yields an all-zero bank.
    khi = max(f2bin(fmax), klo)

    # Spec 2: SpeechLib uses triangles in Mel space
    mlo = mel(fmin)
    mhi = mel(fmax)
    m_centers = np.linspace(mlo, mhi, n_mels + 2)
    # Spacing between consecutive centres, i.e. the half-width of every triangle.
    ms = (mhi - mlo) / (n_mels + 1)

    # Vectorised over (filter, bin): filter m is open on (m_centers[m], m_centers[m + 2])
    # and peaks at m_centers[m + 1].
    mel_bins = bin2mel(np.arange(klo, khi))[None, :]
    left = m_centers[:-2, None]
    center = m_centers[1:-1, None]
    right = m_centers[2:, None]
    inside = (left < mel_bins) & (mel_bins < right)
    weights = np.where(inside, 1.0 - np.abs(center - mel_bins) / ms, 0.0)

    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
    matrix[:, klo:khi] = weights
    return matrix


class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
    """
    Audio feature extractor for Phi4Multimodal.

    Turns raw mono waveforms into log Mel filter-bank features (overlapping frames, per-frame pre-emphasis, Hamming
    window, FFT, SpeechLib-style triangular Mel filters) and computes, for every sample, the number of audio embeddings
    derived from its feature length.

    Args:
        feature_size (`int`, *optional*, defaults to 80):
            Number of Mel filters, i.e. the last dimension of the extracted features.
        sampling_rate (`int`, *optional*, defaults to 16000):
            Sampling rate (Hz) the input audio must have.
        hop_length (`int`, *optional*, defaults to 160):
            Number of samples between the starts of consecutive frames.
        n_fft (`int`, *optional*, defaults to 512):
            FFT size; each frame is zero-padded to this length before the FFT.
        win_length (`int`, *optional*, defaults to 400):
            Number of samples per frame (length of the Hamming window).
        preemphasis (`float`, *optional*, defaults to 0.97):
            Coefficient of the first-order pre-emphasis filter applied inside each frame.
        padding_value (`float`, *optional*, defaults to 0.0):
            Value used to pad shorter waveforms in a batch.
        audio_compression_rate (`int`, *optional*, defaults to 8):
            First ceil-division factor applied to feature lengths to obtain the audio embedding count.
        audio_downsample_rate (`int`, *optional*, defaults to 1):
            Second ceil-division factor, applied after `audio_compression_rate`.
        audio_feat_stride (`int`, *optional*, defaults to 1):
            Multiplier applied to the number of feature frames before computing the embedding count.
        mel_min_frequency (`float`, *optional*, defaults to 0):
            Lowest frequency (Hz) covered by the Mel filter bank.
        mel_max_frequency (`float`, *optional*, defaults to 7690):
            Highest frequency (Hz) covered by the Mel filter bank.
    """

    model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"]

    def __init__(
        self,
        feature_size: int = 80,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        n_fft: int = 512,
        win_length: int = 400,
        preemphasis: float = 0.97,
        padding_value: float = 0.0,
        audio_compression_rate: int = 8,
        audio_downsample_rate: int = 1,
        audio_feat_stride: int = 1,
        mel_min_frequency: float = 0,
        mel_max_frequency: float = 7690,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)

        self.hop_length = hop_length
        self.n_fft = n_fft
        self.win_length = win_length
        self.preemphasis = preemphasis
        self.padding_value = padding_value
        self.audio_compression_rate = audio_compression_rate
        self.audio_downsample_rate = audio_downsample_rate
        self.audio_feat_stride = audio_feat_stride

        # TODO: @eustlb, uncomment and remove speechlib_mel once #36603 is merged.
        # self.mel_filters = mel_filter_bank(
        #     num_frequency_bins=self.n_fft // 2 + 1,
        #     num_mel_filters=self.feature_size,
        #     min_frequency=mel_min_frequency,
        #     max_frequency=mel_max_frequency,
        #     sampling_rate=self.sampling_rate,
        #     triangularize_in_mel_space=True,
        #     mel_scale="kaldi",
        # )
        # Stored transposed, shape (n_fft // 2 + 1, feature_size), so power spectra can be right-multiplied by it.
        self.mel_filters = speechlib_mel(
            self.sampling_rate, self.n_fft, self.feature_size, mel_min_frequency, mel_max_frequency
        ).T

    def __call__(
        self,
        raw_speech: AudioInput,
        sampling_rate: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        padding: Optional[str] = "longest",
        max_length: Optional[int] = None,
        truncation: bool = False,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: Optional[bool] = True,
        device: Optional[str] = "cpu",
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several audio sequence(s). The STFT is computed with
        PyTorch, which is therefore required.

        Args:
            raw_speech (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The sequence or batch of sequences to be processed. Each sequence can be a numpy array or PyTorch tensor.
                For batched inputs, sequences can be a list of numpy arrays or PyTorch tensors, or a single numpy array or
                PyTorch tensor with first dimension being the batch size. Multi-channel audio (channels on the last
                dimension) is averaged down to mono.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.
            pad_to_multiple_of (`int`, *optional*, defaults to None):
                If set will pad the sequence to a multiple of the provided value.
            padding (`str`, *optional*, defaults to "longest"):
                Padding strategy forwarded to `self.pad`, e.g. "longest" to pad to the longest sequence in the batch
                or "max_length" to pad to `max_length`.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length.
            truncation (`bool`, *optional*, defaults to False):
                Activates truncation to cut input sequences longer than *max_length* to *max_length*.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of numpy arrays. Acceptable values are:
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
                - `'tf'`: Return TensorFlow `tf.constant` objects.
            return_attention_mask (`bool`, *optional*, defaults to `True`):
                Whether to return the extracted audio input features' attention mask.
            device (`str`, *optional*, defaults to "cpu"):
                Specifies the device for computation of the audio features. (e.g., "cpu", "cuda")

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
                - **audio_input_features** -- Log Mel features, shape (batch_size, num_frames, feature_size).
                - **audio_embed_sizes** -- Number of audio embeddings for each sample, shape (batch_size,).
                - **audio_attention_mask** -- Boolean mask over feature frames, shape (batch_size, longest feature
                  length in the batch). Only present when `return_attention_mask` is True and the batch holds more
                  than one sample.
                If `return_tensors` is not specified, the fields are returned as PyTorch tensors.

        Raises:
            ValueError: if `sampling_rate` differs from the extractor's sampling rate, or `raw_speech` is an empty
                list/tuple.
        """
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if isinstance(raw_speech, (list, tuple)) and len(raw_speech) == 0:
            raise ValueError("`raw_speech` must contain at least one audio sequence.")

        # Convert NumPy inputs to torch tensors (element-wise for sequences, so mixed lists also work)
        if isinstance(raw_speech, np.ndarray):
            raw_speech = torch.tensor(raw_speech)
        elif isinstance(raw_speech, (list, tuple)):
            raw_speech = [torch.tensor(speech) if isinstance(speech, np.ndarray) else speech for speech in raw_speech]

        is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
        if is_batched_torch and len(raw_speech.shape) > 2:
            logger.warning(
                f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
                "We will take the mean of the channels to convert to mono."
            )
            raw_speech = raw_speech.mean(-1)

        is_batched_sequence = isinstance(raw_speech, (list, tuple))
        if is_batched_sequence:
            # Build a new list: rebinding the loop variable alone would silently drop the downmixed audio.
            mono_speech = []
            for speech in raw_speech:
                if len(speech.shape) > 1:
                    logger.warning(
                        f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
                        "We will take the mean of the channels to convert to mono."
                    )
                    speech = speech.mean(-1)
                mono_speech.append(speech)
            raw_speech = mono_speech

        if is_batched_torch or is_batched_sequence:
            raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
        else:
            raw_speech = [raw_speech[:, None].to(torch.float32)]

        audio_lengths = [len(speech) for speech in raw_speech]

        # convert into correct format for padding
        batched_speech = BatchFeature(data={"audio_input_features": raw_speech, "audio_lengths": audio_lengths})
        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )
        input_features = padded_inputs.audio_input_features.squeeze(-1)
        # With `truncation=True` the waveforms can be shorter than their recorded lengths; keep the lengths
        # consistent with the (possibly truncated) padded waveforms. This is a no-op otherwise.
        audio_lengths = padded_inputs.audio_lengths.clamp(max=input_features.shape[-1])

        input_features = self._torch_extract_fbank_features(input_features, audio_lengths, device)

        feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1
        feature_lengths = feature_lengths * self.audio_feat_stride
        audio_embed_sizes = self._compute_audio_embed_size(feature_lengths)

        # A mask is only produced for batches of more than one sample.
        feature_attention_mask = None
        if len(feature_lengths) > 1:
            frame_positions = torch.arange(0, feature_lengths.max())
            feature_attention_mask = frame_positions[None, :] < feature_lengths[:, None]

        data = {
            "audio_input_features": input_features,
            "audio_embed_sizes": audio_embed_sizes,
        }
        if feature_attention_mask is not None and return_attention_mask:
            data["audio_attention_mask"] = feature_attention_mask

        return BatchFeature(data=data, tensor_type=return_tensors)

    # TODO; @eustlb, move this to audio_utils in a general spectogram_batch function that handles torch and numpy
    def _torch_extract_fbank_features(
        self, waveform: "torch.FloatTensor", audio_lengths: "torch.Tensor", device: str = "cpu"
    ) -> "torch.FloatTensor":
        """
        Compute the log mel-scaled spectrogram of batched waveforms using PyTorch's FFT implementation.

        Args:
            waveform (`torch.FloatTensor` of shape `(batch_size, max_audio_length)`):
                The batched (padded) waveforms.
            audio_lengths (`torch.Tensor` of shape `(batch_size,)`):
                The unpadded lengths of the waveforms along the max_audio_length dimension.
            device (`str`, *optional*, defaults to "cpu"):
                The device to run the computation on. (e.g., "cpu", "cuda")

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_frames, feature_size)`:
                The log mel-scaled spectrogram of the batched waveforms, floored at log(1) = 0.
        """
        fft_window = torch.hamming_window(self.win_length, periodic=False, device=device, dtype=torch.float64)
        # Keep every tensor on the window's device; mixing devices would fail for e.g. device="cuda".
        device = fft_window.device
        waveform = waveform.to(device)
        audio_lengths = audio_lengths.to(device)

        # batched implementation: (batch_size, num_frames, win_length) overlapping frames
        batch_size = waveform.shape[0]
        frames = waveform.unfold(-1, self.win_length, self.hop_length)

        # ---
        # the unbatched (and unpaded) original implementation skips last few audio values that can't be included in a frame
        # we need to ensure that the corresponding frames for the padded input also mask these values
        if batch_size > 1:
            # `unfold` returns overlapping views of `waveform`; copy before writing into them
            frames = frames.clone()
            num_frames = frames.shape[1]
            # concerned batch indices
            to_mask_batch_idxs = torch.arange(batch_size, device=device)[audio_lengths != audio_lengths.max()]
            if to_mask_batch_idxs.numel() > 0:
                lengths = audio_lengths[to_mask_batch_idxs]
                # first frame the unpadded sample would not have produced (0 if shorter than one window)
                batch_idxs_down = ((lengths - self.win_length) // self.hop_length + 1).clamp(min=0)
                # frames from this index on start past the end of the sample and hold only padding
                batch_idxs_up = lengths // self.hop_length + 1
                offset_idx = int(batch_idxs_down.min())
                # never slice past the last frame of the padded batch (the mask must match the slice shape)
                max_idx = min(int(batch_idxs_up.max()), num_frames)

                if max_idx > offset_idx:
                    mask = torch.arange(max_idx - offset_idx, device=device).expand(to_mask_batch_idxs.shape[0], -1)
                    mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & (
                        mask < (batch_idxs_up - offset_idx).unsqueeze(1)
                    )
                    mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length)
                    masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0)
                    frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames
        # ---

        # apply pre-emphasis first order filter on fft windows;
        # the first sample of each frame uses itself as its predecessor
        frames_prev = torch.roll(frames, 1, dims=-1)
        frames_prev[:, :, 0] = frames_prev[:, :, 1]
        # NOTE(review): the 32768 factor presumably rescales [-1, 1] audio to the int16 range SpeechLib expects
        frames = (frames - self.preemphasis * frames_prev) * 32768

        # apply fft (frames are zero-padded from win_length to n_fft samples)
        S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1)
        S = S.view(frames.shape[0], -1, S.shape[-1])
        S = S.to(torch.complex64)

        spec = torch.abs(S)
        spec_power = spec**2

        # apply triangular mel filter bank, then take the log with a floor of 1.0 (log value 0)
        mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
        log_spec = torch.clamp(spec_power @ mel_filters, min=1.0)
        log_spec = torch.log(log_spec)

        return log_spec

    def _compute_audio_embed_size(self, audio_frames):
        """
        Number of audio embeddings for each sample:
        `ceil(ceil(audio_frames / audio_compression_rate) / audio_downsample_rate)`, using integer tensor ops.

        Args:
            audio_frames (`torch.Tensor`): integer number of feature frames per sample.

        Returns:
            `torch.Tensor` with the same shape and dtype as `audio_frames`.
        """
        integer = audio_frames // self.audio_compression_rate
        remainder = audio_frames % self.audio_compression_rate
        result = integer + (remainder > 0).to(integer.dtype)

        integer = result // self.audio_downsample_rate
        remainder = result % self.audio_downsample_rate
        result = integer + (remainder > 0).to(integer.dtype)  # qformer compression

        return result


# Public API of this module.
__all__ = ["Phi4MultimodalFeatureExtractor"]

# Register the class with transformers' auto-class machinery — presumably so it can be loaded through
# `AutoFeatureExtractor` as custom (remote) code; confirm against the model repository's config.
Phi4MultimodalFeatureExtractor.register_for_auto_class()