from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypeVar

import torch

from ltx_pipelines.utils.helpers import cleanup_memory

_M = TypeVar("_M", bound=torch.nn.Module)


@contextmanager
def gpu_model(model: _M) -> Iterator[_M]:
    """Context manager that yields a model and releases its memory on exit.
    Moves all parameters and buffers to ``meta`` device on exit, which
    immediately releases the underlying storage on **both** GPU and CPU,
    then runs ``cleanup_memory()`` to reclaim fragmented CUDA memory.
    Usage::
        with gpu_model(build_encoder()) as encoder:
            ...  # use encoder — typed as the concrete class
        # GPU + CPU memory freed automatically
    """
    try:
        yield model
    finally:
        torch.cuda.synchronize()
        # .to("meta") releases storage for all parameters/buffers regardless
        # of their original device (CUDA or CPU).
        model.to("meta")
        cleanup_memory()
