"""Audio loading, resampling, normalization utilities."""

from __future__ import annotations

from pathlib import Path

import torch
import torchaudio


def load_audio(
    path: str | Path,
    target_sr: int | None = None,
    mono: bool = True,
) -> tuple[torch.Tensor, int]:
    """Read an audio file, optionally downmixing to mono and resampling.

    Args:
        path: Location of the audio file (anything ``str()`` turns into a path).
        target_sr: Desired output sample rate. ``None`` keeps the file's rate.
        mono: When true and the file has more than one channel, channels are
            averaged into a single one.

    Returns:
        wav: float32 tensor, shape [1, T] (mono) or [C, T]
        sr:  output sample rate
    """
    waveform, sample_rate = torchaudio.load(str(path))

    # Downmix before resampling, so with mono=True the resampler only
    # processes a single channel.
    is_multichannel = waveform.shape[0] > 1
    if mono and is_multichannel:
        waveform = waveform.mean(dim=0, keepdim=True)

    needs_resample = target_sr is not None and target_sr != sample_rate
    if not needs_resample:
        return waveform, sample_rate

    resampled = torchaudio.functional.resample(waveform, sample_rate, target_sr)
    return resampled, target_sr


def normalize_peak(wav: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
    """Peak-normalize a waveform so its largest absolute sample sits at ``target_db`` dBFS.

    A single gain is applied to the whole tensor: the peak is taken over every
    element (all batch items and channels together), not per item.

    Args:
        wav: Waveform tensor of any shape.
        target_db: Desired peak level in dB relative to a linear amplitude of
            1.0 (0 dBFS). The default of -1.0 gives a peak of about 0.891.

    Returns:
        A new, scaled tensor. ``wav`` itself (the same object) is returned
        unchanged when it is empty or when its peak is below 1e-8, to avoid
        amplifying silence / dividing by ~0.
    """
    # `.max()` on an empty tensor raises, so short-circuit zero-length input.
    if wav.numel() == 0:
        return wav
    peak = wav.abs().max()
    if peak < 1e-8:
        return wav
    target_linear = 10 ** (target_db / 20.0)
    return wav * (target_linear / peak)


def normalize_rms(wav: torch.Tensor, target_db: float = -20.0) -> torch.Tensor:
    """Scale ``wav`` so its root-mean-square level equals ``target_db``.

    The RMS is computed over every element of the tensor (all batch items and
    channels together), so one gain is applied to the whole tensor.

    Args:
        wav: Waveform tensor of any shape (floating point, since a mean is taken).
        target_db: Desired RMS level in dB relative to a linear amplitude of 1.0.
            The default of -20.0 corresponds to an RMS of 0.1.

    Returns:
        A new, scaled tensor, or ``wav`` itself (the same object) when its RMS
        is below 1e-8, to avoid amplifying silence.
    """
    current_rms = torch.sqrt(torch.mean(torch.square(wav)))
    if current_rms < 1e-8:
        return wav
    desired_rms = 10 ** (target_db / 20.0)
    gain = desired_rms / current_rms
    return wav * gain


def ensure_shape(wav: torch.Tensor) -> torch.Tensor:
    """Return ``wav`` as a 3-D ``[B, C, T]`` tensor (typically ``[B, 1, T]``).

    Leading dims are added as needed:

    - 1-D ``[T]``    -> ``[1, 1, T]``
    - 2-D ``[C, T]`` -> ``[1, C, T]``. The first dim is treated as channels and
      is *not* checked to be 1, so multichannel input stays multichannel.
    - 3-D input is returned unchanged.

    Raises:
        ValueError: if ``wav`` is 0-D or has more than 3 dims.
    """
    if wav.ndim == 1:
        wav = wav.unsqueeze(0).unsqueeze(0)
    elif wav.ndim == 2:
        wav = wav.unsqueeze(0)
    # Explicit raise rather than `assert`: asserts are stripped under
    # `python -O`, which would let malformed tensors through silently.
    if wav.ndim != 3:
        raise ValueError(f"Expected 3-D [B, 1, T], got shape {wav.shape}")
    return wav


def generate_synthetic(
    duration_s: float,
    sr: int,
    batch_size: int = 1,
    device: str = "cpu",
    *,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Generate synthetic speech-like signal for warmup / smoke testing.

    Mix of a few sinusoids (200, 800, 2500, 5000 Hz) + Gaussian noise — not
    real speech, but exercises the full encode/decode path including HF content.

    Args:
        duration_s: Signal length in seconds. The sample count is
            ``int(duration_s * sr)`` (truncated).
        sr: Sample rate in Hz.
        batch_size: Number of items. All items share the same sinusoids but get
            independent noise.
        device: Torch device the output is created on.
        generator: Optional ``torch.Generator`` for reproducible noise. It must
            live on ``device``. ``None`` uses the global RNG.

    Returns:
        float32 tensor of shape ``[batch_size, 1, n_samples]``, each item
        peak-normalized to 0.9. When ``n_samples`` is 0 (duration shorter than
        one sample), an all-empty ``[batch_size, 1, 0]`` tensor is returned.
    """
    n_samples = int(duration_s * sr)
    signal = torch.zeros(batch_size, 1, n_samples, device=device)
    if n_samples == 0:
        # Nothing to synthesize; also avoids `amax` over an empty dim below.
        return signal

    # Sample times k / sr. Using arange keeps the step exactly 1 / sr;
    # linspace(0, duration_s, n) would include the endpoint and stretch the
    # step to duration_s / (n - 1), detuning every tone slightly.
    t = torch.arange(n_samples, device=device) / sr

    for freq in (200.0, 800.0, 2500.0, 5000.0):
        signal += 0.15 * torch.sin(2 * torch.pi * freq * t)
    noise = torch.randn(
        signal.shape, generator=generator, device=device, dtype=signal.dtype
    )
    signal += 0.05 * noise

    # Per-item peak normalization (max over time only); clamp guards against /0.
    peak = signal.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    return signal / peak * 0.9
