# This code is modified from https://github.com/seanghay/uvr-mdx-infer/blob/main/separate.py
# Copyright (c) 2023 seanghay
# This code is from an Apache-2.0 repository.
# Note: This code has been modified to fit the context of this repository.
#       This code is included in an MIT-licensed repository.
#       The repository's MIT license does not apply to this code.
import torch
import numpy as np
import onnxruntime as ort
from tqdm import tqdm


class ConvTDFNet:
    """
    ConvTDFNet - Convolutional Temporal Frequency Domain Network.
    """

    def __init__(self, target_name, L, dim_f, dim_t, n_fft, hop=1024):
        """
        Initialize ConvTDFNet.

        Args:
            target_name (str): The target name for separation.
            L (int): Number of layers.
            dim_f (int): Dimension in the frequency domain.
            dim_t (int): Dimension in the time domain (log2).
            n_fft (int): FFT size.
            hop (int, optional): Hop size. Defaults to 1024.

        Returns:
            None
        """
        super(ConvTDFNet, self).__init__()
        self.dim_c = 4
        self.dim_f = dim_f
        self.dim_t = 2**dim_t
        self.n_fft = n_fft
        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
        self.target_name = target_name

        out_c = self.dim_c * 4 if target_name == "*" else self.dim_c

        self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t])
        self.n = L // 2

    def stft(self, x):
        """
        Perform Short-Time Fourier Transform (STFT).

        Args:
            x (torch.Tensor): Input waveform.

        Returns:
            torch.Tensor: STFT of the input waveform.
        """
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, self.dim_c, self.n_bins, self.dim_t]
        )
        return x[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        """
        Perform Inverse Short-Time Fourier Transform (ISTFT).

        Args:
            x (torch.Tensor): Input STFT.
            freq_pad (torch.Tensor, optional): Frequency padding. Defaults to None.

        Returns:
            torch.Tensor: Inverse STFT of the input.
        """
        freq_pad = (
            self.freq_pad.repeat([x.shape[0], 1, 1, 1])
            if freq_pad is None
            else freq_pad
        )
        x = torch.cat([x, freq_pad], -2)
        c = 4 * 2 if self.target_name == "*" else 2
        x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
            [-1, 2, self.n_bins, self.dim_t]
        )
        x = x.permute([0, 2, 3, 1])
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(
            x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
        )
        return x.reshape([-1, c, self.chunk_size])


class Predictor:
    """
    Predictor class for source separation using ConvTDFNet and ONNX Runtime.
    """

    def __init__(self, args, device):
        """
        Initialize the Predictor.

        Args:
            args (dict): Configuration arguments.
            device (str): Device to run the model ('cuda' or 'cpu').

        Returns:
            None

        Raises:
            ValueError: If the provided device is not 'cuda' or 'cpu'.
        """
        self.args = args
        self.model_ = ConvTDFNet(
            target_name="vocals",
            L=11,
            dim_f=args["dim_f"],
            dim_t=args["dim_t"],
            n_fft=args["n_fft"],
        )

        if device == "cuda":
            self.model = ort.InferenceSession(
                args["model_path"], providers=["CUDAExecutionProvider"]
            )
        elif device == "cpu":
            self.model = ort.InferenceSession(
                args["model_path"], providers=["CPUExecutionProvider"]
            )
        else:
            raise ValueError("Device must be either 'cuda' or 'cpu'")

    def demix(self, mix):
        """
        Separate the sources from the input mix.

        Args:
            mix (np.ndarray): Input mixture signal.

        Returns:
            np.ndarray: Separated sources.

        Raises:
            AssertionError: If margin is zero.
        """
        samples = mix.shape[-1]
        margin = self.args["margin"]
        chunk_size = self.args["chunks"] * 44100

        assert margin != 0, "Margin cannot be zero!"

        if margin > chunk_size:
            margin = chunk_size

        segmented_mix = {}

        if self.args["chunks"] == 0 or samples < chunk_size:
            chunk_size = samples

        counter = -1
        for skip in range(0, samples, chunk_size):
            counter += 1
            s_margin = 0 if counter == 0 else margin
            end = min(skip + chunk_size + margin, samples)
            start = skip - s_margin
            segmented_mix[skip] = mix[:, start:end].copy()
            if end == samples:
                break

        sources = self.demix_base(segmented_mix, margin_size=margin)
        return sources

    def demix_base(self, mixes, margin_size):
        """
        Base function for source separation.

        Args:
            mixes (dict): Dictionary of segmented mixtures.
            margin_size (int): Size of the margin.

        Returns:
            np.ndarray: Separated sources.
        """
        chunked_sources = []
        progress_bar = tqdm(total=len(mixes))
        progress_bar.set_description("Source separation")

        for mix in mixes:
            cmix = mixes[mix]
            sources = []
            n_sample = cmix.shape[1]
            model = self.model_
            trim = model.n_fft // 2
            gen_size = model.chunk_size - 2 * trim
            pad = gen_size - n_sample % gen_size
            mix_p = np.concatenate(
                (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
            )
            mix_waves = []
            i = 0
            while i < n_sample + pad:
                waves = np.array(mix_p[:, i : i + model.chunk_size])
                mix_waves.append(waves)
                i += gen_size

            mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32)

            with torch.no_grad():
                _ort = self.model
                spek = model.stft(mix_waves)
                if self.args["denoise"]:
                    spec_pred = (
                        -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
                        + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
                    )
                    tar_waves = model.istft(torch.tensor(spec_pred))
                else:
                    tar_waves = model.istft(
                        torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
                    )
                tar_signal = (
                    tar_waves[:, :, trim:-trim]
                    .transpose(0, 1)
                    .reshape(2, -1)
                    .numpy()[:, :-pad]
                )

                start = 0 if mix == 0 else margin_size
                end = None if mix == list(mixes.keys())[::-1][0] else -margin_size

                if margin_size == 0:
                    end = None

                sources.append(tar_signal[:, start:end])

                progress_bar.update(1)

            chunked_sources.append(sources)
        _sources = np.concatenate(chunked_sources, axis=-1)

        progress_bar.close()
        return _sources

    def predict(self, mix):
        """
        Predict the separated sources from the input mix.

        Args:
            mix (np.ndarray): Input mixture signal.

        Returns:
            tuple: Tuple containing the mixture minus the separated sources and the separated sources.
        """
        if mix.ndim == 1:
            mix = np.asfortranarray([mix, mix])

        tail = mix.shape[1] % (self.args["chunks"] * 44100)
        if mix.shape[1] % (self.args["chunks"] * 44100) != 0:
            mix = np.pad(
                mix,
                (
                    (0, 0),
                    (
                        0,
                        self.args["chunks"] * 44100
                        - mix.shape[1] % (self.args["chunks"] * 44100),
                    ),
                ),
            )

        mix = mix.T
        sources = self.demix(mix.T)
        opt = sources[0].T

        if tail != 0:
            return ((mix - opt)[: -(self.args["chunks"] * 44100 - tail), :], opt)
        else:
            return ((mix - opt), opt)