"""High-frequency energy retention metric.

Measures how well the share of signal energy above a cutoff frequency is
preserved after encode/decode, reported as the change in that share in dB.
"""

from __future__ import annotations

import torch


def _band_energy(wav: torch.Tensor, sr: int, low_hz: float, high_hz: float) -> torch.Tensor:
    """Energy in [low_hz, high_hz] band via FFT. Input [B, T], returns [B]."""
    n = wav.shape[-1]
    spectrum = torch.fft.rfft(wav, dim=-1)
    magnitude_sq = (spectrum.real ** 2 + spectrum.imag ** 2)

    freqs = torch.fft.rfftfreq(n, d=1.0 / sr, device=wav.device)
    mask = (freqs >= low_hz) & (freqs <= high_hz)
    return magnitude_sq[:, mask].sum(dim=-1)


def hf_energy_ratio(wav: torch.Tensor, sr: int, cutoff_hz: float | None = None) -> torch.Tensor:
    """Fraction of spectral energy lying at or above a cutoff frequency.

    Args:
        wav: Waveforms, ``[B, T]``.
        sr: Sample rate in Hz.
        cutoff_hz: Lower edge of the high band in Hz. When omitted, 4000 Hz is
            used for ``sr <= 16000`` and 6000 Hz otherwise (so 16 kHz -> 4 kHz,
            24 kHz -> 6 kHz). If the cutoff is at or above ``sr / 2`` the high
            band is (nearly) empty and the ratio is (near) zero.

    Returns:
        ``[B]`` ratio of high-band energy to full-band ``[0, sr / 2]`` energy.
        The denominator is clamped to 1e-12, so an all-zero waveform yields 0
        rather than NaN.
    """
    if cutoff_hz is not None:
        cutoff = cutoff_hz
    elif sr <= 16000:
        cutoff = 4000.0
    else:
        cutoff = 6000.0

    upper_edge = sr / 2.0  # Nyquist frequency
    full_energy = _band_energy(wav, sr, 0.0, upper_edge)
    band_energy = _band_energy(wav, sr, cutoff, upper_edge)
    return band_energy / torch.clamp(full_energy, min=1e-12)


def hf_energy_delta_db(
    ref: torch.Tensor,
    deg: torch.Tensor,
    sr: int,
    cutoff_hz: float | None = None,
) -> float:
    """Change, in dB, of the high-frequency energy ratio from ``ref`` to ``deg``.

    Computes ``10 * log10(ratio_deg / ratio_ref)`` per item, where each ratio
    is ``hf_energy_ratio`` of that signal, then averages over the batch.
    Negative values mean HF loss; positive values mean HF gain.

    Both signals are truncated to the shorter length before comparison, and
    both ratios are floored at 1e-12 so silent inputs do not produce inf/NaN.

    Args:
        ref: Reference audio, ``[B, 1, T]`` (``[B, T]`` is also accepted).
            Axis 1 of a 3-D input is squeezed, which only removes it when it
            has size 1.
        deg: Reconstructed audio, same layout as ``ref``.
        sr: Sample rate in Hz.
        cutoff_hz: Forwarded to ``hf_energy_ratio``; ``None`` selects its
            sample-rate-dependent default.

    Returns:
        Mean dB delta over the batch, as a Python float.
    """
    # Drop the singleton channel axis of 3-D inputs ([B, 1, T] -> [B, T]).
    ref, deg = [signal.squeeze(1) if signal.ndim == 3 else signal for signal in (ref, deg)]

    # Compare only the overlapping prefix of the two signals.
    length = min(ref.shape[-1], deg.shape[-1])
    ref_ratio, deg_ratio = [
        hf_energy_ratio(signal[..., :length], sr, cutoff_hz) for signal in (ref, deg)
    ]

    floor = 1e-12
    relative_change = deg_ratio.clamp(min=floor) / ref_ratio.clamp(min=floor)
    per_item_db = 10.0 * torch.log10(relative_change)
    return per_item_db.mean().item()
