"""Batching utilities: fixed-window cropping, padding, and bucketing."""

from __future__ import annotations

import torch


def pad_or_crop(wav: torch.Tensor, target_len: int) -> torch.Tensor:
    """Zero-pad or crop a waveform to exactly ``target_len`` along the last dim.

    Cropping keeps the first ``target_len`` samples and returns a slice of
    ``wav`` (a view, no copy). Padding appends zeros on the right and returns
    a new tensor. Leading dimensions are left untouched.

    Args:
        wav: [..., T]
        target_len: desired T; must be non-negative.

    Returns:
        Tensor of shape [..., target_len].

    Raises:
        ValueError: if ``target_len`` is negative.
    """
    if target_len < 0:
        # A negative value would otherwise be treated as a Python negative
        # slice index below and silently drop samples from the end.
        raise ValueError(f"target_len must be non-negative, got {target_len}")
    cur_len = wav.shape[-1]
    if cur_len >= target_len:
        return wav[..., :target_len]
    # F.pad's pad tuple is (left, right) for the last dimension.
    return torch.nn.functional.pad(wav, (0, target_len - cur_len))


def fixed_window_batch(
    wavs: list[torch.Tensor],
    window_seconds: float,
    sr: int,
    device: str = "cpu",
) -> torch.Tensor:
    """Crop/pad a list of waveforms to fixed-length windows and stack.

    Each waveform is cropped (keeping its beginning) or right-padded with
    zeros to ``round(window_seconds * sr)`` samples. The results are then
    stacked along a new leading batch dimension and moved to ``device``.

    Args:
        wavs: list of tensors, each [1, T_i] or [T_i]. 1-D inputs get a
            leading channel axis added. All inputs must share the same
            leading shape after that for ``torch.stack`` to succeed.
        window_seconds: target duration in seconds.
        sr: sample rate.
        device: target device.

    Returns:
        Batched tensor [B, 1, target_len] for the documented input shapes.
        An empty ``wavs`` yields an empty tensor of shape [0, 1, target_len]
        (default dtype) on ``device``.

    Raises:
        ValueError: if the computed window length is negative.
    """
    # round() rather than int(): int() truncates float noise such as
    # 0.29 * 100 == 28.999999999999996 down to 28 instead of 29.
    target_len = int(round(window_seconds * sr))
    if target_len < 0:
        raise ValueError(
            f"window length must be non-negative, got {window_seconds} s at {sr} Hz"
        )
    if not wavs:
        # torch.stack rejects an empty list; return a well-shaped empty batch.
        return torch.zeros(0, 1, target_len, device=device)
    batch = [
        pad_or_crop(w.unsqueeze(0) if w.ndim == 1 else w, target_len)
        for w in wavs
    ]
    return torch.stack(batch, dim=0).to(device)


def bucket_by_length(
    lengths: list[int],
    bucket_boundaries: list[int],
) -> list[list[int]]:
    """Partition sample indices by length into boundary-delimited buckets.

    Sample ``i`` goes into the first bucket whose bound satisfies
    ``lengths[i] <= bound``. Samples longer than every bound go into a
    trailing overflow bucket. Within each bucket, indices keep their
    original order.

    Args:
        lengths: per-sample lengths, in samples.
        bucket_boundaries: inclusive upper bounds, one per bucket, expected
            sorted ascending (the first bound that fits wins).

    Returns:
        The non-empty buckets in boundary order (overflow bucket last), each
        a list of sample indices. Empty buckets are omitted, so a bucket's
        position in the result need not equal its boundary index.
    """
    overflow = len(bucket_boundaries)
    groups: list[list[int]] = [[] for _ in range(overflow + 1)]
    for sample_idx, sample_len in enumerate(lengths):
        # First boundary that fits; fall back to the overflow slot.
        slot = next(
            (
                pos
                for pos, upper in enumerate(bucket_boundaries)
                if sample_len <= upper
            ),
            overflow,
        )
        groups[slot].append(sample_idx)
    return list(filter(None, groups))
