# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from statistics import NormalDist
from typing import Callable, Tuple

import numpy as np
import torch
from torch import nn
from tqdm import tqdm


class EDMScaling:
    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise


class EDMSDE:
    def __init__(
        self,
        p_mean: float = -1.2,
        p_std: float = 1.2,
        sigma_max: float = 80.0,
        sigma_min: float = 0.002,
    ):
        self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std)
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self._generator = np.random

    def sample_t(self, batch_size: int) -> torch.Tensor:
        cdf_vals = self._generator.uniform(size=(batch_size))
        samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals]
        log_sigma = torch.tensor(samples_interval_gaussian, device="cuda")
        return torch.exp(log_sigma)

    def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        return x0, sigma


class EDMSampler(nn.Module):
    """
    Elucidating the Design Space of Diffusion-Based Generative Models (EDM)
    # https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/generate.py#L25

    Attributes:
        None

    Methods:
        forward(x0_fn: Callable, x_sigma_max: torch.Tensor, num_steps: int = 35, sigma_min: float = 0.002,
                sigma_max: float = 80, rho: float = 7, S_churn: float = 0, S_min: float = 0,
                S_max: float = float("inf"), S_noise: float = 1) -> torch.Tensor:
            Performs the forward pass for the EDM sampling process.

            Parameters:
                x0_fn (Callable): A function that takes in a tensor and returns a denoised tensor.
                x_sigma_max (torch.Tensor): The initial noise level tensor.
                num_steps (int, optional): The number of sampling steps. Default is 35.
                sigma_min (float, optional): The minimum noise level. Default is 0.002.
                sigma_max (float, optional): The maximum noise level. Default is 80.
                rho (float, optional): The rho parameter for time step discretization. Default is 7.
                S_churn (float, optional): The churn parameter for noise increase. Default is 0.
                S_min (float, optional): The minimum value for the churn parameter. Default is 0.
                S_max (float, optional): The maximum value for the churn parameter. Default is float("inf").
                S_noise (float, optional): The noise scale for the churn parameter. Default is 1.

            Returns:
                torch.Tensor: The sampled tensor after the EDM process.
    """

    @torch.no_grad()
    def forward(
        self,
        x0_fn: Callable,
        x_sigma_max: torch.Tensor,
        num_steps: int = 35,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        S_churn: float = 0,
        S_min: float = 0,
        S_max: float = float("inf"),
        S_noise: float = 1,
    ) -> torch.Tensor:
        # Time step discretization.
        in_dtype = x_sigma_max.dtype
        _ones = torch.ones(x_sigma_max.shape[0], dtype=in_dtype, device=x_sigma_max.device)
        step_indices = torch.arange(num_steps, dtype=torch.float64, device=x_sigma_max.device)
        t_steps = (
            sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

        # Main sampling loop.
        x_next = x_sigma_max.to(torch.float64)
        for i, (t_cur, t_next) in enumerate(
            tqdm(zip(t_steps[:-1], t_steps[1:], strict=False), total=len(t_steps) - 1)
        ):  # 0, ..., N-1
            x_cur = x_next

            # Increase noise temporarily.
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur)

            # Euler step.
            denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

            # Apply 2nd order correction.
            if i < num_steps - 1:
                denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64)
                d_prime = (x_next - denoised) / t_next
                x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        return x_next.to(in_dtype)
