"""Log-mel spectrogram L1 loss for codec reconstruction quality."""

from __future__ import annotations

import torch
import torchaudio


def _get_mel_transform(
    sr: int,
    n_fft: int = 1024,
    hop_length: int = 256,
    n_mels: int = 80,
    device: str | torch.device = "cpu",
) -> torchaudio.transforms.MelSpectrogram:
    """Build a power mel-spectrogram transform placed on *device*.

    Args:
        sr: Sample rate forwarded to the transform as ``sample_rate``.
        n_fft: FFT size.
        hop_length: Hop between successive frames, in samples.
        n_mels: Number of mel bands.
        device: Target device for the transform; either a device string
            or a ``torch.device`` (e.g. ``tensor.device``).

    Returns:
        A freshly constructed ``MelSpectrogram`` module moved to *device*.
    """
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        power=2.0,  # power (squared-magnitude) spectrogram
    )
    # Move the module to the requested device so it can be applied to
    # tensors that live there.
    return mel_transform.to(device)


def log_mel_spectrogram(
    wav: torch.Tensor,
    sr: int,
    n_fft: int = 1024,
    hop_length: int = 256,
    n_mels: int = 80,
) -> torch.Tensor:
    """Return the natural-log mel spectrogram of *wav*.

    Input: [..., T]. Output: [..., n_mels, T'].

    The mel transform is created on the same device as *wav*. Mel values
    are floored at 1e-7 before the logarithm so that silent bins do not
    produce ``-inf``.
    """
    transform = _get_mel_transform(
        sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        device=wav.device,
    )
    mel_power = transform(wav)
    floored = torch.clamp(mel_power, min=1e-7)
    return floored.log()


def log_mel_l1(
    ref: torch.Tensor,
    deg: torch.Tensor,
    sr: int,
    n_fft: int = 1024,
    hop_length: int = 256,
    n_mels: int = 80,
) -> float:
    """Return the mean L1 distance between log-mel spectrograms.

    Inputs are expected as [B, 1, T] at the same sample rate; tensors
    without a channel axis ([B, T] or [T]) are accepted as well.

    If the two signals differ in length along the last axis, both are
    truncated to the shorter length before comparison (no padding is
    applied).

    Args:
        ref: Reference audio.
        deg: Degraded (e.g. codec-reconstructed) audio.
        sr: Sample rate shared by both signals.
        n_fft: FFT size for the mel transform.
        hop_length: Hop between frames, in samples.
        n_mels: Number of mel bands.

    Returns:
        The mean absolute difference over all log-mel bins, as a Python
        float. No autograd graph is built.
    """
    # Compare only the overlapping prefix of the two signals.
    min_len = min(ref.shape[-1], deg.shape[-1])
    ref = ref[..., :min_len]
    deg = deg[..., :min_len]

    # Drop the singleton channel axis ([B, 1, T] -> [B, T]) only when one
    # exists; an unconditional squeeze(1) raises on 1-D input and would
    # collapse the time axis of a [B, 1] tensor.
    if ref.dim() >= 3 and ref.shape[1] == 1:
        ref = ref.squeeze(1)
    if deg.dim() >= 3 and deg.shape[1] == 1:
        deg = deg.squeeze(1)

    # Only a scalar float leaves this function, so skip autograd tracking.
    with torch.no_grad():
        mel_ref = log_mel_spectrogram(ref, sr, n_fft, hop_length, n_mels)
        mel_deg = log_mel_spectrogram(deg, sr, n_fft, hop_length, n_mels)
        return (mel_ref - mel_deg).abs().mean().item()
