# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License

# Copyright (c) 2020 Shimin Zhang

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window


support_clp_op = None
if th.__version__ >= "1.7.0":
    from torch.fft import rfft as fft

    support_clp_op = True
else:
    from torch import rfft as fft


class STFT(th.nn.Module):
    def __init__(
        self,
        win_len=1024,
        win_hop=512,
        fft_len=1024,
        enframe_mode="continue",
        win_type="hann",
        win_sqrt=False,
        pad_center=True,
    ):
        """
        Implement of STFT using 1D convolution and 1D transpose convolutions.
        Implement of framing the signal in 2 ways, `break` and `continue`.
        `break` method is a kaldi-like framing.
        `continue` method is a librosa-like framing.

        More information about `perfect reconstruction`:
        1. https://ww2.mathworks.cn/help/signal/ref/stft.html
        2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html

        Args:
            win_len (int): Number of points in one frame.  Defaults to 1024.
            win_hop (int): Number of framing stride. Defaults to 512.
            fft_len (int): Number of DFT points. Defaults to 1024.
            enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'.
            win_type (str, optional): The type of window to create. Defaults to 'hann'.
            win_sqrt (bool, optional): using square root window. Defaults to True.
            pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True.
        """
        super(STFT, self).__init__()
        assert enframe_mode in ["break", "continue"]
        assert fft_len >= win_len
        self.win_len = win_len
        self.win_hop = win_hop
        self.fft_len = fft_len
        self.mode = enframe_mode
        self.win_type = win_type
        self.win_sqrt = win_sqrt
        self.pad_center = pad_center
        self.pad_amount = self.fft_len // 2

        en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
        self.register_buffer("en_k", en_k)
        self.register_buffer("fft_k", fft_k)
        self.register_buffer("ifft_k", ifft_k)
        self.register_buffer("ola_k", ola_k)

    def __init_kernel__(self):
        """
        Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
        ** enframe_kernel: Using conv1d layer and identity matrix.
        ** fft_kernel: Using linear layer for matrix multiplication. In fact,
        enframe_kernel and fft_kernel can be combined, But for the sake of
        readability, I took the two apart.
        ** ifft_kernel, pinv of fft_kernel.
        ** overlap-add kernel, just like enframe_kernel, but transposed.

        Returns:
            tuple: four kernels.
        """
        enframed_kernel = th.eye(self.fft_len)[:, None, :]
        if support_clp_op:
            tmp = fft(th.eye(self.fft_len))
            fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
        else:
            fft_kernel = fft(th.eye(self.fft_len), 1)
        if self.mode == "break":
            enframed_kernel = th.eye(self.win_len)[:, None, :]
            fft_kernel = fft_kernel[: self.win_len]
        fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
        ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
        window = get_window(self.win_type, self.win_len)

        self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
        window = th.FloatTensor(window)
        if self.mode == "continue":
            left_pad = (self.fft_len - self.win_len) // 2
            right_pad = left_pad + (self.fft_len - self.win_len) % 2
            window = F.pad(window, (left_pad, right_pad))
        if self.win_sqrt:
            self.padded_window = window
            window = th.sqrt(window)
        else:
            self.padded_window = window**2

        fft_kernel = fft_kernel.T * window
        ifft_kernel = ifft_kernel * window
        ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
        if self.mode == "continue":
            ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
        return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel

    def is_perfect(self):
        """
        Whether the parameters win_len, win_hop and win_sqrt
        obey constants overlap-add(COLA)

        Returns:
            bool: Return true if parameters obey COLA.
        """
        return self.perfect_reconstruct and self.pad_center

    def transform(self, inputs, return_type="complex"):
        """Take input data (audio) to STFT domain.

        Args:
            inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
            return_type (str, optional): return (mag, phase) when `magphase`,
            return (real, imag) when `realimag` and complex(real, imag) when `complex`.
            Defaults to 'complex'.

        Returns:
            tuple: (mag, phase) when `magphase`, return (real, imag) when
            `realimag`. Defaults to 'complex', each elements with shape
            [num_batch, num_frequencies, num_frames]
        """
        assert return_type in ["magphase", "realimag", "complex"]
        if inputs.dim() == 2:
            inputs = th.unsqueeze(inputs, 1)
        self.num_samples = inputs.size(-1)
        if self.pad_center:
            inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
        enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
        outputs = th.transpose(enframe_inputs, 1, 2)
        outputs = F.linear(outputs, self.fft_k)
        outputs = th.transpose(outputs, 1, 2)
        dim = self.fft_len // 2 + 1
        real = outputs[:, :dim, :]
        imag = outputs[:, dim:, :]
        if return_type == "realimag":
            return real, imag
        elif return_type == "complex":
            assert support_clp_op
            return th.complex(real, imag)
        else:
            mags = th.sqrt(real**2 + imag**2)
            phase = th.atan2(imag, real)
            return mags, phase

    def inverse(self, input1, input2=None, input_type="magphase"):
        """Call the inverse STFT (iSTFT), given tensors produced
        by the `transform` function.

        Args:
            input1 (tensors): Magnitude/Real-part of STFT with shape
            [num_batch, num_frequencies, num_frames]
            input2 (tensors): Phase/Imag-part of STFT with shape
            [num_batch, num_frequencies, num_frames]
            input_type (str, optional): Mathematical meaning of input tensor's.
            Defaults to 'magphase'.

        Returns:
            tensors: Reconstructed audio given magnitude and phase. Of
                shape [num_batch, num_samples]
        """
        assert input_type in ["magphase", "realimag"]
        if input_type == "realimag":
            real, imag = None, None
            if support_clp_op and th.is_complex(input1):
                real, imag = input1.real, input1.imag
            else:
                real, imag = input1, input2
        else:
            real = input1 * th.cos(input2)
            imag = input1 * th.sin(input2)
        inputs = th.cat([real, imag], dim=1)
        outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
        t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
        t = t.to(inputs.device)
        coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)

        num_frames = input1.size(-1)
        num_samples = num_frames * self.win_hop

        rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples

        outputs = outputs[..., rm_start:rm_end]
        coff = coff[..., rm_start:rm_end]
        coffidx = th.where(coff > 1e-8)
        outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
        return outputs.squeeze(dim=1)

    def forward(self, inputs):
        """Take input data (audio) to STFT domain and then back to audio.

        Args:
            inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]

        Returns:
            tensor: Reconstructed audio given magnitude and phase.
            Of shape [num_batch, num_samples]
        """
        mag, phase = self.transform(inputs)
        rec_wav = self.inverse(mag, phase)
        return rec_wav
