"""Batch-splitting adapter for the transformer.
Wraps an ``X0Model`` (or ``LayerStreamingWrapper``) and splits batched inputs
into smaller chunks before forwarding, then concatenates the results. This
controls peak activation memory at the cost of more forward passes.
The adapter is transparent — it has the same ``forward`` signature as
``X0Model`` and proxies attribute access to the wrapped model.
Example
-------
>>> from ltx_core.batch_split import BatchSplitAdapter
>>> adapter = BatchSplitAdapter(model, max_batch_size=1)
>>> # Receives B=4, runs 4xB=1 internally, returns B=4
>>> denoised_video, denoised_audio = adapter(video=v_b4, audio=a_b4, perturbations=ptb)
"""

from __future__ import annotations

from typing import Any

import torch
from torch import nn

from ltx_core.guidance.perturbations import BatchedPerturbationConfig
from ltx_core.model.transformer.modality import Modality


def _split_perturbations(config: BatchedPerturbationConfig, sizes: list[int]) -> list[BatchedPerturbationConfig]:
    """Split a ``BatchedPerturbationConfig`` along the batch dimension."""
    it = iter(config.perturbations)
    return [BatchedPerturbationConfig([next(it) for _ in range(s)]) for s in sizes]


def _merge_tensors(tensors: list[torch.Tensor | None]) -> torch.Tensor | None:
    """Concatenate tensors along batch dim, or return None if all are None."""
    non_none = [t for t in tensors if t is not None]
    if not non_none:
        return None
    return torch.cat(non_none, dim=0)


class BatchSplitAdapter(nn.Module):
    """Wraps a model and splits batched forward calls into smaller chunks.
    Has the same ``forward`` signature as ``X0Model``:
    ``(video, audio, perturbations) -> (denoised_video, denoised_audio)``.
    Args:
        model: The model to wrap (``X0Model``, ``LayerStreamingWrapper``, etc.).
        max_batch_size: Maximum batch size per forward pass. Input batches
            larger than this are split into sequential chunks.
    """

    def __init__(self, model: nn.Module, max_batch_size: int) -> None:
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        super().__init__()
        self._model = model
        self._max_batch_size = max_batch_size

    def _get_chunk_sizes(self, batch_size: int) -> list[int]:
        full, remainder = divmod(batch_size, self._max_batch_size)
        sizes = [self._max_batch_size] * full
        if remainder:
            sizes.append(remainder)
        return sizes

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        batch_size = (video or audio).latent.shape[0]

        if batch_size <= self._max_batch_size:
            return self._model(video=video, audio=audio, perturbations=perturbations)

        sizes = self._get_chunk_sizes(batch_size)
        n = len(sizes)

        v_chunks = video.split(sizes) if video is not None else [None] * n
        a_chunks = audio.split(sizes) if audio is not None else [None] * n
        p_chunks = _split_perturbations(perturbations, sizes)

        chunk_results = [
            self._model(video=vc, audio=ac, perturbations=pc)
            for vc, ac, pc in zip(v_chunks, a_chunks, p_chunks, strict=True)
        ]

        results_v, results_a = zip(*chunk_results, strict=True)
        return _merge_tensors(list(results_v)), _merge_tensors(list(results_a))

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        """Proxy attribute access to the wrapped model."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self._model, name)
