# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from xformers.ops import masked_matmul
from xformers.sparse import _csr_ops
from xformers.sparse.utils import (
    _csr_to_coo,
    _dense3d_to_sparse,
    _diffsort,
    _get_transpose_info,
    _transpose_with_info,
)


class SparseCSRTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, row_offsets, column_indices, values, shape):
        """Allocate the wrapper tensor that carries the logical 3-D ``shape``.

        The wrapper takes its device, dtype, layout and ``requires_grad`` flag
        from ``values``. ``row_offsets`` and ``column_indices`` are accepted so
        the signature lines up with ``__init__`` but are not read here.
        """
        tensor_options = {
            "device": values.device,
            "dtype": values.dtype,
            "layout": values.layout,
            "requires_grad": values.requires_grad,
        }
        # Only batched matrices (three dimensions) are supported.
        assert len(shape) == 3
        return torch.Tensor._make_wrapper_subclass(cls, shape, **tensor_options)

    def __init__(self, row_offsets, column_indices, values, shape):
        """Store the CSR components and precompute the transpose metadata.

        ``row_offsets`` and ``column_indices`` must be 1-D and ``values`` must
        be 2-D. ``shape`` is not read here; the dimensions are taken from
        ``self.shape``, which ``__new__`` has already set.
        """
        for component, expected_rank in (
            (row_offsets, 1),
            (column_indices, 1),
            (values, 2),
        ):
            assert component.ndim == expected_rank

        self.__row_offsets = row_offsets.contiguous()
        # Row indices come from ``_diffsort`` applied to the offsets as given,
        # cast back to the offsets' dtype.
        self.__row_indices = _diffsort(row_offsets).to(row_offsets.dtype)
        self.__column_indices = column_indices.contiguous()
        self.__values = values.contiguous()

        dims = self.shape
        self.__transp_info = _get_transpose_info(
            dims[1],
            dims[2],
            self.__row_indices,
            self.__row_offsets,
            self.__column_indices,
        )

    def __repr__(self):
        """Return a short description showing the shape and the stored values."""
        return f"sparse_csr_tensor(shape={self.shape}, values={self.__values})"

    @classmethod
    def from_dense(cls, matrix):
        """Build a sparse tensor from the dense tensor ``matrix``.

        The CSR components are produced by ``_dense3d_to_sparse`` on the
        device of ``matrix``; the result keeps ``matrix.shape``.
        """
        sparse_parts = _dense3d_to_sparse(matrix, matrix.device)
        # The row indices returned by the helper are dropped: the constructor
        # derives its own from the row offsets.
        values, _, row_offsets, column_indices = sparse_parts
        return cls(row_offsets, column_indices, values, matrix.shape)

    @classmethod
    def from_sparse_coo(cls, arg0):
        """Build a sparse CSR tensor from a sparse COO tensor (not implemented).

        Args:
            arg0: the sparse COO tensor to convert.

        Raises:
            NotImplementedError: always. The conversion has not been written
                yet; failing loudly avoids handing callers a silent ``None``
                where they expect a tensor.
        """
        # Sketch of the intended implementation:
        #   assert arg0.is_sparse
        #   x = arg0.coalesce()
        #   rows, cols = x.indices().unbind(0)
        #   vals = x.values()
        #   _coo_to_csr()
        raise NotImplementedError(
            f"{cls.__name__}.from_sparse_coo is not implemented yet"
        )

    @classmethod
    def _wrap(
        cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
    ):
        """Assemble an instance from already-prepared CSR components.

        ``__init__`` is bypassed: the tensors are stored exactly as given,
        with no rank checks, no ``contiguous()`` calls and no recomputation
        of the row indices or transpose metadata.
        """
        wrapped = cls.__new__(cls, row_offsets, column_indices, values, shape)
        wrapped.__transp_info = _transp_info
        wrapped.__column_indices = column_indices
        wrapped.__row_offsets = row_offsets
        wrapped.__row_indices = row_indices
        wrapped.__values = values
        return wrapped

    def values(self):
        """Return the values tensor stored for this sparse matrix."""
        return self.__values

    @property
    def _csr_row_indices(self):
        """The stored row-indices tensor."""
        return self.__row_indices

    @property
    def _csr_row_offsets(self):
        """The stored CSR row-offsets tensor."""
        return self.__row_offsets

    @property
    def _csr_column_indices(self):
        """The stored CSR column-indices tensor."""
        return self.__column_indices

    @property
    def _csr_transp_info(self):
        """The stored transpose metadata, as produced by ``_get_transpose_info``."""
        return self.__transp_info

    @classmethod
    def _bmm(cls, arg0, arg1):
        """Multiply the sparse ``arg0`` by the dense ``arg1`` with ``_spmm``.

        Returns ``NotImplemented`` unless ``arg0`` is an instance of this
        class and ``arg1`` is exactly a ``torch.Tensor`` (subclasses are
        rejected). Both operands must be 3-D.
        """
        supported = isinstance(arg0, cls) and type(arg1) is torch.Tensor
        if not supported:
            return NotImplemented

        assert arg0.ndim == 3
        assert arg1.ndim == 3

        # Only the row count of the sparse operand is handed to the kernel.
        _, num_rows, _ = arg0.shape
        return _csr_ops._spmm.apply(
            arg1,
            arg0.__row_indices,
            arg0.__values,
            arg0.__row_offsets,
            arg0.__column_indices,
            num_rows,
            arg0.__transp_info,
        )

    @classmethod
    def _softmax(cls, arg0, dim):
        """Apply ``_SparseSoftmax`` along the last dimension of ``arg0``.

        Only ``dim == -1`` or ``dim == 2`` is handled; any other value yields
        ``NotImplemented``. The result reuses the index tensors and transpose
        metadata of ``arg0``; only the values are new.
        """
        if dim not in (-1, 2):
            return NotImplemented

        _, num_rows, num_cols = arg0.shape
        row_indices = arg0.__row_indices
        row_offsets = arg0.__row_offsets
        column_indices = arg0.__column_indices
        new_values = _csr_ops._SparseSoftmax.apply(
            num_rows, num_cols, row_indices, arg0.__values, row_offsets, column_indices
        )
        return cls._wrap(
            arg0.shape,
            new_values,
            row_indices,
            row_offsets,
            column_indices,
            arg0.__transp_info,
        )

    @classmethod
    def _transpose(cls, arg0, dim0, dim1):
        """Swap the last two dimensions of the sparse tensor ``arg0``.

        Args:
            arg0: the sparse tensor to transpose.
            dim0, dim1: the dimensions to swap. One must name the row
                dimension (``1`` or ``-2``) and the other the column dimension
                (``2`` or ``-1``); since a swap is symmetric, either order is
                accepted.

        Returns:
            A new instance of shape ``(B, n, m)``, or ``NotImplemented`` for
            any other pair of dimensions.
        """
        # TODO: check if need to return this or not
        row_dims, col_dims = (1, -2), (2, -1)
        swaps_last_two = (dim0 in row_dims and dim1 in col_dims) or (
            dim0 in col_dims and dim1 in row_dims
        )
        if not swaps_last_two:
            return NotImplemented

        B, m, n = arg0.shape
        values = arg0.__values

        (
            output_row_indices,
            output_values,
            output_row_offsets,
            output_column_indices,
        ) = _transpose_with_info(values, arg0.__transp_info)
        # The transposed matrix needs its own metadata so that it can be
        # transposed again later.
        new_transp_info = _get_transpose_info(
            n, m, output_row_indices, output_row_offsets, output_column_indices
        )

        return cls._wrap(
            (B, n, m),
            output_values,
            output_row_indices,
            output_row_offsets,
            output_column_indices,
            new_transp_info,
        )

    @classmethod
    def _masked_matmul(cls, a, b, mask):
        """Compute ``a @ b`` at the positions stored in ``mask`` using ``_sddmm``.

        ``a`` and ``b`` must both be exactly ``torch.Tensor`` (otherwise
        ``NotImplemented`` is returned). The result has the shape of ``mask``
        and reuses its index tensors and transpose metadata.
        """
        both_dense = type(a) is torch.Tensor and type(b) is torch.Tensor
        if not both_dense:
            return NotImplemented

        # The mask must match the rows of ``a`` and the columns of ``b``.
        assert mask.shape[1] == a.shape[1]
        assert mask.shape[2] == b.shape[2]

        row_indices = mask.__row_indices
        row_offsets = mask.__row_offsets
        column_indices = mask.__column_indices

        lhs = a.contiguous()
        # The kernel is given ``b`` with its last two dims swapped, made
        # contiguous.
        rhs_transposed = b.transpose(-2, -1).contiguous()
        product_values = _csr_ops._sddmm.apply(
            lhs,
            rhs_transposed,
            row_indices,
            row_offsets,
            column_indices,
            mask.__transp_info,
        )
        # TODO add bias here
        return cls._wrap(
            mask.shape,
            product_values,
            row_indices,
            row_offsets,
            column_indices,
            mask.__transp_info,
        )

    @classmethod
    def _to(cls, arg0, device):
        """Return a copy of ``arg0`` with every component moved to ``device``.

        ``device`` may be a ``torch.device`` or a string accepted by
        ``torch.device``; anything else fails the assertion.
        """
        target = torch.device(device) if isinstance(device, str) else device
        assert isinstance(target, torch.device)

        shape = arg0.shape
        values = arg0.__values.to(device=target)
        row_indices = arg0.__row_indices.to(device=target)
        row_offsets = arg0.__row_offsets.to(device=target)
        column_indices = arg0.__column_indices.to(device=target)
        transp_info = tuple(t.to(device=target) for t in arg0.__transp_info)
        return cls._wrap(
            shape, values, row_indices, row_offsets, column_indices, transp_info
        )

    @classmethod
    def _copy(cls, arg0, arg1):
        """Copy every component of ``arg1`` into ``arg0`` in place.

        Each destination tensor is resized to match its source before the
        copy. Returns ``arg0``, or ``NotImplemented`` unless both arguments
        are instances of this class. The shapes must be equal.
        """
        if not isinstance(arg0, cls) or not isinstance(arg1, cls):
            return NotImplemented
        assert arg0.shape == arg1.shape

        def _component_pairs():
            # Lazily yields (destination, source) pairs, one component at a
            # time, in the order they are copied.
            yield arg0.__values, arg1.__values
            yield arg0.__row_indices, arg1.__row_indices
            yield arg0.__row_offsets, arg1.__row_offsets
            yield arg0.__column_indices, arg1.__column_indices
            yield from zip(arg0.__transp_info, arg1.__transp_info)

        for destination, source in _component_pairs():
            destination.resize_as_(source).copy_(source)
        return arg0

    @classmethod
    def _equal(cls, arg0, arg1):
        """Return whether two sparse tensors match.

        They match when their shapes, values, row offsets and column indices
        are all equal. Returns ``NotImplemented`` unless both arguments are
        instances of this class.
        """
        if not isinstance(arg0, cls) or not isinstance(arg1, cls):
            return NotImplemented
        # Short-circuits in order: the shape check comes before any tensor
        # comparison.
        return (
            arg0.shape == arg1.shape
            and torch.equal(arg0.__values, arg1.__values)
            and torch.equal(arg0.__row_offsets, arg1.__row_offsets)
            and torch.equal(arg0.__column_indices, arg1.__column_indices)
        )

    @classmethod
    def _to_dense(cls, arg0):
        """Materialise ``arg0`` as a dense tensor of the same shape.

        Positions not present in the sparsity pattern are zero. The output has
        the dtype and device of ``arg0``.
        """
        full_shape = arg0.shape
        _, num_rows, num_cols = full_shape
        dense = torch.zeros(full_shape, dtype=arg0.dtype, device=arg0.device)

        # Indices are converted to int64 before being used for indexing.
        offsets = arg0.__row_offsets.long()
        cols = arg0.__column_indices.long()
        rows, _ = _csr_to_coo(num_rows, num_cols, offsets, cols)

        stored = arg0.__values
        # One index per entry along dim 0 of the values, as a column so that it
        # broadcasts against the row/column index vectors.
        batch_index = torch.arange(len(stored), device=arg0.device).unsqueeze(1)
        dense[batch_index, rows, cols] = stored
        return dense

    @classmethod
    def _binary_op(cls, func, arg0, arg1):
        if not (
            isinstance(arg0, (cls, int, float)) and isinstance(arg1, (cls, int, float))
        ):
            return NotImplemented
        v0, v1 = arg0, arg1
        if isinstance(arg0, cls):
            v0 = arg0.__values
        if isinstance(arg1, cls):
            v1 = arg1.__values
        # assert arg0.shape == arg1.shape
        if isinstance(arg0, cls) and isinstance(arg1, cls):
            msg = f"arg0 and arg1 need to have the same sparsity pattern in {func} (for now)"
            if not arg0.__row_offsets.shape == arg1.__row_offsets.shape:
                raise NotImplementedError(msg)
            if not arg0.__column_indices.shape == arg1.__column_indices.shape:
                raise NotImplementedError(msg)
            if not arg0.__values.shape == arg1.__values.shape:
                raise NotImplementedError(msg)
            # TODO this is not always true, but is a fast approximation for now
            if arg0.__row_offsets is not arg1.__row_offsets:
                raise NotImplementedError(msg)
            if arg0.__column_indices is not arg1.__column_indices:
                raise NotImplementedError(msg)
        out = func(v0, v1)
        return cls._wrap(
            arg0.shape,
            out,
            arg0.__row_indices,
            arg0.__row_offsets,
            arg0.__column_indices,
            arg0.__transp_info,
        )

    @classmethod
    def _binary_op_slow(cls, func, arg0, arg1):
        """Apply ``func`` by going through dense tensors.

        Sparse operands are converted with ``to_dense()``, ``func`` is applied
        to the dense operands, and the result is converted back with
        ``from_dense``. Shapes are not checked here.
        """
        dense0 = arg0.to_dense() if isinstance(arg0, cls) else arg0
        dense1 = arg1.to_dense() if isinstance(arg1, cls) else arg1
        return cls.from_dense(func(dense0, dense1))

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route torch functions that involve a sparse tensor.

        Functions with a sparse implementation (matmul/bmm, softmax,
        transpose, ``masked_matmul``, add, mul, logical_and, dropout, ``to``,
        ``copy_``, ``equal``, ``to_dense``, ``detach``, ``__deepcopy__``, the
        ``grad`` getter) are forwarded to the matching helper. Everything
        else runs the default implementation with torch-function dispatch
        disabled, and the result is converted back to this class unless
        ``func`` is one of torch's default no-wrap functions.

        Args:
            func: the torch function being called.
            types: the tensor-like types taking part in the call (unused).
            args: positional arguments of the call.
            kwargs: keyword arguments of the call, or ``None``.
        """
        if kwargs is None:
            kwargs = {}
        if func in [
            torch.Tensor.bmm,
            torch.bmm,
            torch.Tensor.__matmul__,
            torch.matmul,
            torch.Tensor.matmul,
        ]:
            assert len(args) == 2
            return cls._bmm(args[0], args[1])

        if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]:
            # ``dim`` may be given by keyword (``softmax(x, dim=-1)``) or
            # positionally (``x.softmax(-1)``); accept both.
            dim = kwargs["dim"] if "dim" in kwargs else args[1]
            return cls._softmax(args[0], dim)

        if func in [torch.Tensor.transpose, torch.transpose]:
            assert len(kwargs) == 0
            return cls._transpose(args[0], args[1], args[2])

        if func == masked_matmul:
            assert len(args) == 3
            return cls._masked_matmul(args[0], args[1], args[2])

        if func in [
            torch.Tensor.add,
            torch.add,
            torch.Tensor.__add__,
        ]:
            assert len(args) == 2
            if not (isinstance(args[0], cls) and isinstance(args[1], cls)):
                raise NotImplementedError(
                    f"{func} with {type(args[0])} and {type(args[1])} not implemented"
                )
            return cls._binary_op(func, args[0], args[1])

        if func in [
            torch.Tensor.mul,
            torch.mul,
            torch.Tensor.__mul__,
        ]:
            assert len(args) == 2
            return cls._binary_op(func, args[0], args[1])

        if func in [torch.Tensor.logical_and, torch.logical_and, torch.Tensor.__and__]:
            assert len(args) == 2
            return cls._binary_op_slow(func, args[0], args[1])

        if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]:
            x = args[0]
            # Work on a copy so that an in-place variant cannot modify the
            # values of the input.
            values = x.__values.clone()
            values = func(values, *args[1:], **kwargs)
            return cls._wrap(
                x.shape,
                values,
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )

        if func == torch.Tensor.to:
            # Only device moves are supported. The device may be given
            # positionally (``x.to("cuda")``) or by keyword
            # (``x.to(device="cuda")``).
            if len(args) >= 2:
                device = args[1]
            else:
                assert "device" in kwargs, "only .to(device) is supported"
                device = kwargs["device"]
            return cls._to(args[0], device)

        if func in [torch.Tensor.copy_]:
            assert len(args) == 2
            return cls._copy(args[0], args[1])

        if func in [torch.Tensor.equal, torch.equal]:
            assert len(args) == 2
            return cls._equal(args[0], args[1])

        if func == torch.Tensor.to_dense:
            assert len(args) == 1
            return cls._to_dense(args[0])

        if func == torch.Tensor.detach:
            x = args[0]
            return cls._wrap(
                x.shape,
                x.__values.detach(),
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )

        if func == torch.Tensor.__deepcopy__:
            x = args[0]
            memo = args[1]
            return cls._wrap(
                x.shape,
                x.__values.__deepcopy__(memo),
                x.__row_indices.__deepcopy__(memo),
                x.__row_offsets.__deepcopy__(memo),
                x.__column_indices.__deepcopy__(memo),
                tuple(v.__deepcopy__(memo) for v in x.__transp_info),
            )

        if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]:
            assert len(args) == 1
            assert len(kwargs) == 0
            x = args[0]
            grad_values = x.__values.grad
            if grad_values is None:
                # No gradient has been accumulated yet; mirror a plain tensor
                # instead of trying to wrap ``None``.
                return None
            return cls._wrap(
                x.shape,
                grad_values,
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )

        if func == torch.Tensor.requires_grad_:
            # Forward the requested flag to the values too, then fall through
            # so that the wrapper itself is updated by the default path.
            func(args[0].__values, *args[1:], **kwargs)

        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            # TODO: check this
            if func in torch.overrides.get_default_nowrap_functions():
                return ret
            return torch._tensor._convert(ret, cls)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        """Decline every dispatch-level call by returning ``NotImplemented``."""
        return NotImplemented
