"""Multi-resolution STFT loss for general-purpose codec quality assessment."""

from __future__ import annotations

import torch


def _stft(x: torch.Tensor, fft_size: int, hop_size: int, win_size: int) -> torch.Tensor:
    """Return the magnitude spectrogram of ``x``.

    A Hann window of ``win_size`` samples is created on the same device and
    with the same dtype as ``x``, the complex STFT is taken with it, and the
    element-wise absolute value is returned.

    Shapes: input [B, T] -> output [B, F, T'].
    """
    hann = torch.hann_window(win_size, dtype=x.dtype, device=x.device)
    spectrum = torch.stft(
        x,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_size,
        window=hann,
        return_complex=True,
    )
    # Discard phase: only the magnitude is used by the loss terms.
    return torch.abs(spectrum)


def spectral_convergence(ref_mag: torch.Tensor, deg_mag: torch.Tensor) -> torch.Tensor:
    """Spectral convergence: ||ref - deg||_F / ||ref||_F.

    Both norms are taken over the last two dimensions, so one value is
    produced per leading (batch) index.
    """
    last_two = (-2, -1)
    error_norm = (ref_mag - deg_mag).norm(dim=last_two)
    # Floor the reference norm so an all-zero reference does not divide by zero.
    ref_norm = torch.clamp(ref_mag.norm(dim=last_two), min=1e-8)
    return error_norm / ref_norm


def log_stft_magnitude_loss(ref_mag: torch.Tensor, deg_mag: torch.Tensor) -> torch.Tensor:
    """Mean absolute difference of log magnitudes.

    Magnitudes are floored at 1e-7 before the logarithm, and the mean is
    taken over the last two dimensions (one value per leading index).
    """
    floor = 1e-7  # keeps log() finite for zero-valued bins
    log_ref = torch.log(torch.clamp(ref_mag, min=floor))
    log_deg = torch.log(torch.clamp(deg_mag, min=floor))
    abs_diff = torch.abs(log_ref - log_deg)
    return abs_diff.mean(dim=(-2, -1))


def multi_resolution_stft_loss(
    ref: torch.Tensor,
    deg: torch.Tensor,
    fft_sizes: tuple[int, ...] = (512, 1024, 2048),
    hop_sizes: tuple[int, ...] = (128, 256, 512),
    win_sizes: tuple[int, ...] = (512, 1024, 2048),
) -> float:
    """Averaged multi-resolution STFT loss (spectral convergence + log mag).

    For every (fft, hop, win) triple the magnitude spectrograms of ``ref`` and
    ``deg`` are compared with ``spectral_convergence`` and
    ``log_stft_magnitude_loss``; each term is averaged over the batch, the two
    are summed, and the sums are averaged over all resolutions.

    Args:
        ref: Reference signal, [B, 1, T] or [B, T].
        deg: Degraded signal, [B, 1, T] or [B, T], at the same sample rate.
            Both signals are truncated to the shorter of the two lengths.
        fft_sizes: FFT size per resolution.
        hop_sizes: Hop size per resolution.
        win_sizes: Window length per resolution.

    Returns:
        The loss as a Python float.

    Raises:
        ValueError: If no resolution is given, or if ``fft_sizes``,
            ``hop_sizes`` and ``win_sizes`` differ in length.
    """
    num_resolutions = len(fft_sizes)
    if num_resolutions == 0:
        raise ValueError("at least one STFT resolution is required")
    # zip() would silently drop the surplus entries while the average still
    # divided by len(fft_sizes), so mismatched configurations are rejected.
    if len(hop_sizes) != num_resolutions or len(win_sizes) != num_resolutions:
        raise ValueError(
            "fft_sizes, hop_sizes and win_sizes must have the same length, got "
            f"{num_resolutions}, {len(hop_sizes)} and {len(win_sizes)}"
        )

    # Drop the channel axis of [B, 1, T] inputs; squeeze(1) leaves the tensor
    # unchanged when that axis is not of size 1.
    if ref.ndim == 3:
        ref = ref.squeeze(1)
    if deg.ndim == 3:
        deg = deg.squeeze(1)

    # Align lengths so both spectrograms have the same number of frames.
    min_len = min(ref.shape[-1], deg.shape[-1])
    ref = ref[..., :min_len]
    deg = deg[..., :min_len]

    total = 0.0
    for fft_s, hop_s, win_s in zip(fft_sizes, hop_sizes, win_sizes):
        ref_mag = _stft(ref, fft_s, hop_s, win_s)
        deg_mag = _stft(deg, fft_s, hop_s, win_s)
        sc = spectral_convergence(ref_mag, deg_mag).mean().item()
        lm = log_stft_magnitude_loss(ref_mag, deg_mag).mean().item()
        total += sc + lm

    return total / num_resolutions
