# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Sequence, Tuple, Union

import torch

from .common import _get_storage_base


def get_stack_strides(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[Union[int, torch.SymInt], ...]]:
    """
    Return the strides of the stacked view if ``tensors`` are already laid
    out in memory as if stacked along dimension ``dim``; otherwise ``None``.

    Args:
        tensors: candidate tensors. Fewer than two tensors always yields
            ``None``.
        dim: dimension of the stacked result along which the tensors would be
            stacked. The result has ``tensors[0].ndim + 1`` dimensions, so the
            accepted range is ``[-(ndim + 1), ndim]``; negative values wrap
            around. Out-of-range values yield ``None``.

    Returns:
        A tuple with ``tensors[0].ndim + 1`` strides, usable with
        ``tensors[0].as_strided``, or ``None`` when any tensor differs from
        the first in shape, strides or dtype, is not located at a regular
        offset from it, or does not report the same storage base.
    """
    if len(tensors) <= 1:
        return None

    first = tensors[0]
    ndim = first.ndim
    if dim < -(ndim + 1) or dim > ndim:
        return None
    if dim < 0:
        # Wrap negative dims against the stacked rank (ndim + 1)
        dim += ndim + 1

    base_offset = first.storage_offset()
    base_strides = tuple(first.stride())
    # The stride along the new dimension is the distance (in elements) between
    # two consecutive tensors.
    # PyTorch 2.5 messed up the type annotations for SymInt, but 2.6 will fix it
    # https://github.com/pytorch/pytorch/issues/138478
    stack_stride = tensors[1].storage_offset() - base_offset  # type: ignore[operator]
    final_stride = (*base_strides[:dim], stack_stride, *base_strides[dim:])

    storage_data_ptr: Optional[int] = None
    for position, x in enumerate(tensors[1:], start=1):
        # Sanity checks
        if x.shape != first.shape:
            return None
        if x.stride() != first.stride():
            return None
        # Storage offsets are counted in elements, so they are only comparable
        # between tensors of the same dtype
        if x.dtype != first.dtype:
            return None
        # Every tensor must sit exactly `position` steps after the first one
        if (
            x.storage_offset()
            != base_offset + position * stack_stride  # type: ignore[operator]
        ):
            return None
        # Computed lazily: only needed once the cheap checks above have passed
        if storage_data_ptr is None:
            storage_data_ptr = _get_storage_base(first)
        # Actual storage check
        if _get_storage_base(x) != storage_data_ptr:
            return None
    return final_stride


def _stack_or_none_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> Optional[torch.Tensor]:
    """
    Build the stacked tensor as a strided view of ``tensors[0]``, without
    copying, when `get_stack_strides` reports that the inputs are already
    stacked in memory along ``dim``.

    Args:
        tensors: tensors to stack.
        dim: dimension of the result along which to stack. Negative values
            wrap around against the rank of the stacked result
            (``tensors[0].ndim + 1``).

    Returns:
        The stacked view, or ``None`` when ``tensors`` is empty, ``dim`` is
        still negative after wrapping, or `get_stack_strides` returns ``None``.
    """
    if len(tensors) == 0:
        return None

    first = tensors[0]
    if dim < 0:
        # `list.insert` treats a negative index differently from a tensor
        # dimension (insert(-1, x) lands before the last item), so resolve the
        # dimension to its non-negative form up front.
        dim += first.ndim + 1
        if dim < 0:
            return None

    strides = get_stack_strides(tensors, dim)
    if strides is None:
        return None

    stacked_shape = list(first.shape)
    stacked_shape.insert(dim, len(tensors))
    return first.as_strided(stacked_shape, strides)


def _stack_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> torch.Tensor:
    """
    Stack ``tensors`` along ``dim``.

    Uses the result of `_stack_or_none_fw` when it is not ``None``, and falls
    back to :attr:`torch.stack` otherwise.
    """
    stacked_view = _stack_or_none_fw(tensors, dim)
    if stacked_view is not None:
        return stacked_view
    return torch.stack(tensors, dim=dim)


class _Unbind(torch.autograd.Function):
    """
    Autograd function backing `unbind`.

    Forward splits the input along ``dim``; backward re-assembles the incoming
    gradients along the same dimension through `_stack_fw`.
    """

    @staticmethod
    # type: ignore
    def forward(ctx, x: torch.Tensor, dim: int):
        # Keep the dimension around: backward stacks along it again
        ctx.dim = dim
        pieces = x.unbind(dim)
        return pieces

    @classmethod
    # type: ignore
    def backward(cls, ctx, *tensors: torch.Tensor):
        grad_x = _stack_fw(tensors, ctx.dim)
        # One entry per forward argument: gradient for `x`, nothing for `dim`
        return grad_x, None


class _StackOrNone(torch.autograd.Function):
    """
    Autograd function backing `stack_or_none`.

    Forward delegates to `_stack_or_none_fw`; backward splits the incoming
    gradient along the stacking dimension, one piece per input tensor.
    """

    @staticmethod
    # type: ignore
    def forward(ctx, dim: int, *tensors: torch.Tensor):
        # Keep the dimension around: backward unbinds along it
        ctx.dim = dim
        stacked = _stack_or_none_fw(tensors, dim=dim)
        return stacked

    @classmethod
    # type: ignore
    def backward(cls, ctx, grad: torch.Tensor):
        per_tensor_grads = tuple(grad.unbind(dim=ctx.dim))
        # Leading None: `dim` is the first forward argument and gets no gradient
        return (None,) + per_tensor_grads


def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
    """
    Forward pass is identical to :attr:`torch.unbind`.
    The backward pass stacks the gradients back through `_stack_fw`, which
    skips the :attr:`torch.stack` copy when the gradients are already views
    of the same storage laid out as the stacked result.
    """
    pieces = _Unbind.apply(x, dim)
    return pieces


def stack_or_none(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[torch.Tensor]:
    """
    Does exactly the same as :attr:`torch.stack` if the tensors can be stacked
    without any memory operation, i.e. they are already views of one storage
    laid out as the stacked result. Otherwise returns :code:`None`.

    Args:
        tensors: tensors to stack. Fewer than two tensors always gives
            :code:`None`.
        dim: dimension along which to stack.

    Returns:
        A view over the existing storage, or :code:`None` when no such view
        exists. Callers must handle the :code:`None` case themselves.
    """
    return _StackOrNone.apply(dim, *tensors)
