"""Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)."""

from __future__ import annotations

import torch


def si_sdr(
    ref: torch.Tensor,
    deg: torch.Tensor,
    *,
    eps: float = 1e-8,
) -> float:
    """Compute SI-SDR in dB, averaged over the batch.

    Both signals are made zero-mean, then ``deg`` is projected onto ``ref``:

        s_target = <deg, ref> / <ref, ref> * ref
        e_noise  = deg - s_target
        SI-SDR   = 10 * log10(||s_target||^2 / ||e_noise||^2)

    Args:
        ref: reference waveform, ``[B, 1, T]`` (dim 1 is squeezed) or ``[B, T]``.
        deg: degraded/reconstructed waveform, same layout as ``ref``.
        eps: floor applied to the reference, target and noise energies so the
            ratio stays finite for silent signals or perfect reconstructions.

    Returns:
        The per-sample SI-SDR values averaged over the batch, as a Python float.

    Both inputs are truncated to the shorter length. Computation is carried
    out in at least float32, so half-precision or integer inputs neither
    overflow nor lose the ``eps`` floor.
    """
    if ref.ndim == 3:
        ref = ref.squeeze(1)
    if deg.ndim == 3:
        deg = deg.squeeze(1)

    # Promote to a common floating dtype of at least float32: fp16 energy sums
    # over long signals overflow (max 65504), an eps of 1e-8 rounds to 0 in
    # fp16, and integer tensors cannot be averaged. float64 stays float64.
    dtype = torch.promote_types(
        torch.promote_types(ref.dtype, deg.dtype), torch.float32
    )

    min_len = min(ref.shape[-1], deg.shape[-1])
    ref = ref[..., :min_len].to(dtype)
    deg = deg[..., :min_len].to(dtype)

    ref = ref - ref.mean(dim=-1, keepdim=True)
    deg = deg - deg.mean(dim=-1, keepdim=True)

    dot = (ref * deg).sum(dim=-1, keepdim=True)
    ref_energy = (ref ** 2).sum(dim=-1, keepdim=True).clamp(min=eps)

    # Optimal scaling of the reference: s_target = <deg, ref> / <ref, ref> * ref
    s_target = (dot / ref_energy) * ref
    e_noise = deg - s_target

    # Clamp both energies: an unclamped zero numerator (silent estimate or
    # silent reference) would give -inf and poison the batch mean.
    target_energy = (s_target ** 2).sum(dim=-1).clamp(min=eps)
    noise_energy = (e_noise ** 2).sum(dim=-1).clamp(min=eps)

    si_sdr_val = 10 * torch.log10(target_energy / noise_energy)
    return si_sdr_val.mean().item()
