# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random

import numpy as np
import pytest
import torch

from nemo.collections.tts.modules.transformer_2501 import (
    ConvolutionLayer,
    CrossAttention,
    PositionwiseConvFF,
    SelfAttention,
    Transformer,
    TransformerLayer,
)
from nemo.collections.tts.parts.utils.tts_dataset_utils import beta_binomial_prior_distribution


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


@pytest.mark.unit
class TestConvolutionLayer:
    @classmethod
    def setup_class(cls):
        cls.in_channels = 3
        cls.out_channels = 6
        cls.kernel_size = 3
        cls.stride = 1
        cls.dilation = 1
        cls.bias = True
        # fmt:off
        cls.input_tensor = torch.Tensor(
            [[[-1.0542,  0.2675,  0.6963,  0.4738,  0.3910, -0.1505,  0.9171, -0.1528,  3.7269,  0.1779],
              [-1.0317,  1.6818,  1.4257, -0.5003, -1.7254,  0.8830, -0.4541, -0.4631, -0.0986,  0.5083],
              [-0.3231, -1.0899,  0.5774,  0.1661,  0.9620, -2.3307, -0.6158, -0.3663,  1.2469, -1.0208]]]
        )
        # fmt:on

    def test_non_causal_forward(self):
        set_seed(0)
        layer = ConvolutionLayer(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            dilation=self.dilation,
            bias=self.bias,
            is_causal=False,
        )

        with torch.no_grad():
            output_tensor = layer(self.input_tensor)

        # fmt:off
        expected_output_tensor = torch.Tensor(
            [[[ 0.1912, -0.0555, -0.2681, -0.2289,  1.0788, -0.3908,  0.0936, -0.7962,  1.3754, -0.0731],
              [-0.3715, -0.6326, -0.9596, -0.0933, -0.1024, -0.2082, -0.5924, 0.1097, -0.5418, -0.0854],
              [ 0.3974,  0.4537,  0.3299,  0.1471, -0.5983, -0.8645,  0.0975, 0.6063, -0.6619, -0.9711],
              [-0.3048,  0.3862, -0.2462, -0.9903, -0.6189,  0.7389,  0.0785, -1.0870, -1.0018, -1.2426],
              [-0.4357, -0.0446,  0.0879,  0.0930, -0.2242,  0.5285,  0.4006, -0.1846,  0.5668, -0.5242],
              [-0.0625,  0.4123, -0.6289, -0.4317,  0.1595,  0.0386, -1.0774, 0.2218,  0.8483, -0.4886]]]
        )
        # fmt:on

        assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4)

    def test_causal_forward(self):
        set_seed(0)
        layer = ConvolutionLayer(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            dilation=self.dilation,
            bias=self.bias,
            is_causal=True,
        )

        with torch.no_grad():
            output_tensor = layer(self.input_tensor)

        # fmt:off
        expected_output_tensor = torch.Tensor(
            [[[ 0.4301,  0.1912, -0.0555, -0.2681, -0.2289,  1.0788, -0.3908, 0.0936, -0.7962,  1.3754],
              [-0.0501, -0.3715, -0.6326, -0.9596, -0.0933, -0.1024, -0.2082, -0.5924,  0.1097, -0.5418],
              [-0.4204,  0.3974,  0.4537,  0.3299,  0.1471, -0.5983, -0.8645, 0.0975,  0.6063, -0.6619],
              [ 0.1543, -0.3048,  0.3862, -0.2462, -0.9903, -0.6189,  0.7389, 0.0785, -1.0870, -1.0018],
              [-0.1337, -0.4357, -0.0446,  0.0879,  0.0930, -0.2242,  0.5285, 0.4006, -0.1846,  0.5668],
              [-0.6127, -0.0625,  0.4123, -0.6289, -0.4317,  0.1595,  0.0386, -1.0774,  0.2218,  0.8483]]]
        )
        # fmt:on

        assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4)


class TestPositionwiseConvFF:
    @classmethod
    def setup_class(cls):
        cls.d_model = 3
        cls.d_ffn = 12
        cls.p_dropout = 0.0
        cls.kernel_size = 3
        cls.bias = True
        # fmt:off
        cls.input_tensor = torch.Tensor(
            [[[-1.6682, -0.6069,  0.1321],
              [-1.5489,  0.3279, -0.9159],
              [-0.7490,  1.8984,  0.5030],
              [-0.8130,  0.0058, -1.9979],
              [-1.4994, -0.3270,  1.4961],
              [-1.6613, -1.7827,  0.8932],
              [-0.6276, -1.0770, -0.9971],
              [ 1.5424,  1.3590,  1.2287],
              [-0.1543,  0.3365,  1.7475],
              [-0.1753,  0.4115,  0.0772]]]
        )
        # fmt:on

    def test_causal_forward(self):
        set_seed(0)
        layer = PositionwiseConvFF(
            self.d_model, self.d_ffn, self.p_dropout, self.kernel_size, bias=self.bias, is_causal=True
        )

        with torch.no_grad():
            output_tensor = layer(self.input_tensor)

        # fmt:off
        expected_output_tensor = torch.Tensor(
            [[[-0.1242, -0.0114,  0.0212],
              [-0.0441, -0.0555, -0.0795],
              [-0.0282,  0.0366, -0.2033],
              [-0.0421,  0.0305, -0.2573],
              [-0.1877, -0.2492, -0.1638],
              [-0.4300, -0.1160,  0.2177],
              [-0.1652, -0.3130, -0.3329],
              [-0.1737,  0.1133, -0.1802],
              [-0.2599, -0.0381,  0.1362],
              [-0.0584, -0.2936,  0.2719]]]
        )
        # fmt:on

        assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4)

    def test_non_causal_forward(self):
        set_seed(0)
        layer = PositionwiseConvFF(
            self.d_model, self.d_ffn, self.p_dropout, self.kernel_size, bias=self.bias, is_causal=False
        )

        with torch.no_grad():
            output_tensor = layer(self.input_tensor)

        # fmt:off
        expected_output_tensor = torch.Tensor(
            [[[-0.0617, -0.0321, -0.1646],
              [-0.0421,  0.0305, -0.2573],
              [-0.1877, -0.2492, -0.1638],
              [-0.4300, -0.1160,  0.2177],
              [-0.1652, -0.3130, -0.3329],
              [-0.1737,  0.1133, -0.1802],
              [-0.2599, -0.0381,  0.1362],
              [-0.0584, -0.2936,  0.2719],
              [ 0.0361,  0.1110,  0.0441],
              [-0.0244,  0.0682,  0.0340]]]
        )
        # fmt:on

        assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4)


class TestSelfAttention:
    @classmethod
    def setup_class(cls):
        cls.n_heads = 2
        cls.d_model = 4
        cls.p_dropout = 0.0
        cls.max_length_causal_mask = 6
        # fmt:off
        cls.query_tensor = torch.Tensor(
            [[[ 0.7239, -0.2362, -0.6610, -1.3759],
              [ 1.7381,  0.0793, -1.1241,  0.9529],
              [-1.9809,  0.2217,  0.0795,  0.0307],
              [ 0.3208,  0.4485,  0.3046, -0.0704],
              [-1.4412,  0.8981,  0.1219,  0.0481],
              [ 1.7811, -0.1358,  0.6073,  0.8275]]]
        )
        # fmt:on

    def test_causal_forward(self):
        set_seed(0)
        layer = SelfAttention(
            self.n_heads,
            self.d_model,
            self.p_dropout,
            is_causal=True,
            max_length_causal_mask=self.max_length_causal_mask,
        )
        query_mask = torch.ones(1, self.max_length_causal_mask).bool()
        with torch.no_grad():
            output_tensor, attn_output = layer(self.query_tensor, query_mask)

        # fmt:off
        expected_output_tensor = torch.Tensor(
            [[[-0.2569,  0.2782, -0.0348,  0.2480],
              [-0.3949,  0.4054, -0.0876,  0.2574],
              [-0.1033,  0.0659, -0.0259,  0.1738],
              [-0.1485,  0.0995, -0.0415,  0.0684],
              [-0.0123,  0.0185, -0.0027,  0.0708],
              [-0.0672,  0.0566, -0.0214, -0.0021]]]
        )
        expected_attn_prob = torch.Tensor(
            [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
               [0.5828, 0.4172, 0.0000, 0.0000, 0.0000, 0.0000],
               [0.3216, 0.4260, 0.2525, 0.0000, 0.0000, 0.0000],
               [0.2385, 0.2238, 0.2872, 0.2504, 0.0000, 0.0000],
               [0.1807, 0.1973, 0.2057, 0.2045, 0.2118, 0.0000],
               [0.1159, 0.1388, 0.2010, 0.1721, 0.2161, 0.1562]],
              [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
               [0.2866, 0.7134, 0.0000, 0.0000, 0.0000, 0.0000],
               [0.2799, 0.2472, 0.4729, 0.0000, 0.0000, 0.0000],
               [0.2964, 0.2535, 0.2075, 0.2427, 0.0000, 0.0000],
               [0.1864, 0.1616, 0.2394, 0.1974, 0.2152, 0.0000],
               [0.1666, 0.2030, 0.1391, 0.1649, 0.1546, 0.1719]]]]
        )
        expected_attn_score = torch.Tensor(
            [[[[ 0.5248, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
               [-0.1948, -0.5291, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
               [-0.0279,  0.2533, -0.2698, float('-inf'), float('-inf'), float('-inf')],
               [-0.0508, -0.1145,  0.1350, -0.0020, float('-inf'), float('-inf')],
               [-0.0985, -0.0105,  0.0315,  0.0257,  0.0604, float('-inf')],
               [-0.3253, -0.1457,  0.2250,  0.0694,  0.2971, -0.0275]],
              [[ 0.5075, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
               [-0.5541,  0.3578, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
               [-0.2499, -0.3738,  0.2746, float('-inf'), float('-inf'), float('-inf')],
               [ 0.2215,  0.0654, -0.1351,  0.0216, float('-inf'), float('-inf')],
               [-0.1011, -0.2439,  0.1488, -0.0438,  0.0425, float('-inf')],
               [ 0.0526,  0.2502, -0.1277,  0.0424, -0.0221,  0.0840]]]]
        )
        # fmt:on

        assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4)
        assert torch.allclose(attn_output[0], expected_attn_prob, atol=1e-4)
        assert torch.allclose(attn_output[1], expected_attn_score, atol=1e-4)

    def test_non_causal_forward(self):
        set_seed(0)
        layer = SelfAttention(
            self.n_heads,
            self.d_model,
            self.p_dropout,
            is_causal=False,
            max_length_causal_mask=self.max_length_causal_mask,
        )
        query_mask = torch.ones(1, self.max_length_causal_mask).bool()
        with torch.no_grad():
            output_tensor, attn_output = layer(self.query_tensor, query_mask)

        # fmt:off
        expected_output_tensor = torch.Tensor(
            [[[-0.0954,  0.1131, -0.0195,  0.0704],
              [-0.0401,  0.0364, -0.0156,  0.0088],
              [ 0.0324,  0.0368,  0.0174,  0.0501],
              [-0.0633,  0.0610, -0.0176,  0.0180],
              [-0.0017,  0.0361,  0.0030,  0.0319],
              [-0.0672,  0.0566, -0.0214, -0.0021]]]
        )
        expected_attn_prob = torch.Tensor(
            [[[[0.2835, 0.1482, 0.1723, 0.1426, 0.1422, 0.1111],
               [0.1254, 0.0898, 0.2819, 0.1493, 0.2682, 0.0854],
               [0.1549, 0.2051, 0.1216, 0.1663, 0.1294, 0.2228],
               [0.1583, 0.1485, 0.1906, 0.1662, 0.1890, 0.1475],
               [0.1498, 0.1635, 0.1705, 0.1696, 0.1755, 0.1711],
               [0.1159, 0.1388, 0.2010, 0.1721, 0.2161, 0.1562]],
              [[0.2536, 0.1945, 0.1079, 0.1628, 0.1233, 0.1579],
               [0.0866, 0.2156, 0.1709, 0.1551, 0.1903, 0.1815],
               [0.1361, 0.1202, 0.2300, 0.1626, 0.1941, 0.1569],
               [0.2038, 0.1744, 0.1427, 0.1669, 0.1488, 0.1634],
               [0.1565, 0.1357, 0.2010, 0.1657, 0.1807, 0.1604],
               [0.1666, 0.2030, 0.1391, 0.1649, 0.1546, 0.1719]]]]
        )
        expected_attn_score = torch.Tensor(
            [[[[ 5.2482e-01, -1.2346e-01,  2.7022e-02, -1.6210e-01, -1.6488e-01, -4.1190e-01],
               [-1.9484e-01, -5.2910e-01,  6.1538e-01, -2.0263e-02,  5.6540e-01, -5.7875e-01],
               [-2.7873e-02,  2.5326e-01, -2.6980e-01,  4.3127e-02, -2.0759e-01, 3.3584e-01],
               [-5.0756e-02, -1.1455e-01,  1.3498e-01, -2.0208e-03,  1.2687e-01, -1.2113e-01],
               [-9.8482e-02, -1.0463e-02,  3.1513e-02,  2.5712e-02,  6.0431e-02, 3.4494e-02],
               [-3.2530e-01, -1.4574e-01,  2.2503e-01,  6.9375e-02,  2.9711e-01, -2.7542e-02]],
              [[ 5.0748e-01,  2.4198e-01, -3.4704e-01,  6.4137e-02, -2.1347e-01, 3.3762e-02],
               [-5.5410e-01,  3.5778e-01,  1.2559e-01,  2.8689e-02,  2.3316e-01, 1.8558e-01],
               [-2.4989e-01, -3.7383e-01,  2.7462e-01, -7.2002e-02,  1.0508e-01, -1.0771e-01],
               [ 2.2154e-01,  6.5375e-02, -1.3510e-01,  2.1609e-02, -9.3194e-02, 3.4042e-04],
               [-1.0105e-01, -2.4395e-01,  1.4884e-01, -4.3842e-02,  4.2481e-02, -7.6735e-02],
               [ 5.2595e-02,  2.5018e-01, -1.2765e-01,  4.2375e-02, -2.2093e-02, 8.4005e-02]]]]
        )
        # fmt:on

        assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4)
        assert torch.allclose(attn_output[0], expected_attn_prob, atol=1e-4)
        assert torch.allclose(attn_output[1], expected_attn_score, atol=1e-5)


class TestCrossAttention:
    @classmethod
    def setup_class(cls):
        cls.n_heads = 2
        cls.d_model = 4
        cls.d_memory = 3
        cls.p_dropout = 0.0
        cls.max_length = 6
        # fmt:off
        # shape = (1, cls.max_length, cls.d_model)
        cls.query_tensor = torch.Tensor(
            [[[0.7352, -0.5871, -0.1204, -2.0200],
              [-0.4618, -0.3604, 1.2287, -0.3434],
              [0.7838, -0.7646, -1.3349, -0.1538],
              [-0.9749, -1.0789, -0.0126, -0.7225],
              [2.6929, -0.2091, 2.1242, -1.0123],
              [0.5094, -2.0566, 1.3922, -0.2156]]]
        )
        # fmt:on
        cls.query_mask = torch.ones(1, cls.query_tensor.shape[1]).bool()
        # fmt:off
        # shape = (1, 5, cls.d_memory)
        cls.memory_tensor = torch.Tensor(
            [[[ 2.0132e-01, -5.6582e-01,  1.1191e+00],
              [-6.2371e-01, -9.3398e-02, -1.3744e+00],
              [-9.8265e-01, -8.1742e-01,  4.5611e-01],
              [-5.4802e-01, -1.1218e+00,  7.6138e-01],
              [-1.9899e+00, -1.7910e-03,  9.0718e-01]]]
        )
        # fmt:on
        cls.memory_mask = torch.ones(1, cls.memory_tensor.shape[1]).bool()
        # shape = (1, cls.query_tensor.shape[1], cls.memory_tensor.shape[1])
        cls.attn_prior = torch.from_numpy(
            beta_binomial_prior_distribution(
                phoneme_count=cls.memory_tensor.shape[1], mel_count=cls.query_tensor.shape[1]
            )
        ).unsqueeze(0)

    def test_forward(self):
        set_seed(0)
        layer = CrossAttention(self.n_heads, self.d_model, self.d_memory, self.p_dropout)

        with torch.no_grad():
            output_tensor, attn_output = layer(
                self.query_tensor, self.query_mask, self.memory_tensor, self.memory_mask, self.attn_prior
            )

        # fmt:off
        expected_output_tensor = torch.Tensor(
            [[[ 0.2267, -0.2271,  0.0573, -0.0681],
              [ 0.2672, -0.1823,  0.0722, -0.0859],
              [ 0.3212, -0.2218,  0.0835, -0.0715],
              [ 0.3568, -0.2573,  0.0918, -0.0789],
              [ 0.3962, -0.4112,  0.0816, -0.1972],
              [ 0.3457, -0.4253,  0.0568, -0.2216]]]
        )
        expected_attn_prob = torch.Tensor(
            [[[[0.4220, 0.4859, 0.0709, 0.0188, 0.0025],
               [0.3944, 0.3475, 0.1642, 0.0784, 0.0155],
               [0.1335, 0.3448, 0.2794, 0.1752, 0.0671],
               [0.0914, 0.3300, 0.2343, 0.2437, 0.1006],
               [0.0256, 0.1138, 0.2145, 0.3343, 0.3119],
               [0.0117, 0.0617, 0.1112, 0.3354, 0.4800]],
              [[0.8045, 0.1024, 0.0661, 0.0242, 0.0028],
               [0.4020, 0.2953, 0.1914, 0.0907, 0.0207],
               [0.1446, 0.2798, 0.3026, 0.1965, 0.0766],
               [0.0718, 0.2151, 0.2778, 0.2719, 0.1634],
               [0.0673, 0.0341, 0.1929, 0.4534, 0.2522],
               [0.0064, 0.0264, 0.0999, 0.2872, 0.5802]]]]
        )
        expected_attn_score = torch.Tensor(
            [[[[-0.5044,  0.4476, -0.4961, -0.5728, -0.8010],
               [-0.0761, -0.2027, -0.5103, -0.4385, -0.6724],
               [-0.1525,  0.2576,  0.0471, -0.0138,  0.0075],
               [-0.3084, -0.0058, -0.7538, -0.7144, -1.0604],
               [-0.0799,  0.0260, -0.1511, -0.1493, -0.2187],
               [-0.1935, -0.3225, -0.9867, -0.8637, -1.3162]],
              [[ 0.4704, -0.7801, -0.2374,  0.0126, -0.3430],
               [ 0.0623, -0.2461, -0.2380, -0.1743, -0.2658],
               [ 0.0104,  0.1313,  0.2096,  0.1834,  0.2217],
               [-0.0859,  0.0300, -0.1194, -0.1410, -0.1111],
               [ 0.7876, -1.2782, -0.3570,  0.0556, -0.5312],
               [ 0.1046, -0.2652, -0.1856, -0.1104, -0.2180]]]]
        )
        # fmt:on

        assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4)
        assert torch.allclose(attn_output[0], expected_attn_prob, atol=1e-4)
        assert torch.allclose(attn_output[1], expected_attn_score, atol=1e-4)


class TestTransformerLayer:
    @classmethod
    def setup_class(cls):
        cls.d_model = 2
        cls.d_ffn = 8
        cls.sa_n_heads = 2
        cls.kernel_size = 3
        cls.p_dropout = 0.0
        cls.max_length_causal_mask = 5
        # fmt:off
        # shape = (1, cls.max_length_causal_mask, cls.d_model)
        cls.x = torch.Tensor(
            [[[ 0.5115,  0.0889],
              [-0.8568, -2.9632],
              [-1.3728,  0.7325],
              [-2.4593, -0.9018],
              [ 0.9621,  0.4212]]]
        )
        # fmt:on
        cls.x_mask = torch.ones(1, cls.max_length_causal_mask).bool()
        # fmt:off
        # shape = (1, 3, cls.d_model)
        cls.cond = torch.Tensor(
            [[[ 1.4441,  0.1393],
              [ 0.2828, -0.2456],
              [-0.3075,  0.6581]]]
        )
        # fmt:on
        cls.cond_mask = torch.ones(1, cls.cond.shape[1]).bool()
        # shape = (1, cls.x.shape[1], cls.cond.shape[1])
        cls.attn_prior = torch.from_numpy(
            beta_binomial_prior_distribution(phoneme_count=cls.cond.shape[1], mel_count=cls.x.shape[1])
        ).unsqueeze(0)

    def test_forward_causal_self_attn_and_has_xattn(self):
        set_seed(0)
        layer = TransformerLayer(
            self.d_model,
            self.d_ffn,
            self.sa_n_heads,
            self.kernel_size,
            self.p_dropout,
            has_xattn=True,
            xa_n_heads=2,
            xa_d_memory=2,
            is_causal=True,
            max_length_causal_mask=self.max_length_causal_mask,
        )

        with torch.no_grad():
            output_dict = layer(self.x, self.x_mask, self.cond, self.cond_mask, self.attn_prior)

        # fmt:off
        expected_output = {
            'output': torch.Tensor(
                [[[ 0.1936,  0.5387],
                  [-1.0270, -2.5452],
                  [-1.6884,  0.8765],
                  [-2.7496, -0.7887],
                  [ 1.2837,  0.1172]]]
            ),
            'attn_probabilities': {
                'self_attn_probabilities': [
                    torch.Tensor(
                        [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                           [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
                           [0.3068, 0.3068, 0.3864, 0.0000, 0.0000],
                           [0.2213, 0.2213, 0.2787, 0.2787, 0.0000],
                           [0.2180, 0.2180, 0.1730, 0.1730, 0.2180]],
                          [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                           [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
                           [0.3237, 0.3237, 0.3527, 0.0000, 0.0000],
                           [0.2393, 0.2393, 0.2607, 0.2607, 0.0000],
                           [0.2068, 0.2068, 0.1898, 0.1898, 0.2068]]]]
                    ),
                    torch.Tensor(
                        [[[[0.1154, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
                           [0.1154, 0.1154, float('-inf'), float('-inf'), float('-inf')],
                           [-0.1154, -0.1154, 0.1154, float('-inf'), float('-inf')],
                           [-0.1154, -0.1154, 0.1154, 0.1154, float('-inf')],
                           [0.1154, 0.1154, -0.1154, -0.1154, 0.1154]],
                          [[0.0429, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
                           [0.0429, 0.0429, float('-inf'), float('-inf'), float('-inf')],
                           [-0.0429, -0.0429, 0.0429, float('-inf'), float('-inf')],
                           [-0.0429, -0.0429, 0.0429, 0.0429, float('-inf')],
                           [0.0429, 0.0429, -0.0429, -0.0429, 0.0429]]]]
                    )
                ],
                'cross_attn_probabilities': [
                    torch.Tensor(
                        [[[[0.7181, 0.2394, 0.0426],
                           [0.4843, 0.3874, 0.1283],
                           [0.2753, 0.4129, 0.3118],
                           [0.1344, 0.3583, 0.5074],
                           [0.0520, 0.2599, 0.6882]],
                          [[0.5959, 0.1987, 0.2054],
                           [0.2837, 0.2270, 0.4893],
                           [0.3740, 0.5610, 0.0651],
                           [0.2355, 0.6280, 0.1365],
                           [0.0108, 0.0542, 0.9349]]]]
                    ),
                    torch.Tensor(
                        [[[[0.0586, 0.0586, -0.0586],
                           [0.0624, 0.0624, -0.0624],
                           [-0.0624, -0.0624, 0.0624],
                           [-0.0624, -0.0624, 0.0624],
                           [0.0624, 0.0624, -0.0624]],
                          [[-0.8214, -0.8214, 0.8214],
                           [-0.8745, -0.8744, 0.8745],
                           [0.8745, 0.8744, -0.8745],
                           [0.8745, 0.8744, -0.8745],
                           [-0.8744, -0.8744, 0.8744]]]]
                    )
                ]
            }
        }
        # fmt:on

        assert torch.allclose(output_dict["output"], expected_output["output"], atol=1e-4)
        for i in range(2):
            assert torch.allclose(
                output_dict["attn_probabilities"]["self_attn_probabilities"][i],
                expected_output["attn_probabilities"]["self_attn_probabilities"][i],
                atol=1e-4,
            )
            assert torch.allclose(
                output_dict["attn_probabilities"]["cross_attn_probabilities"][i],
                expected_output["attn_probabilities"]["cross_attn_probabilities"][i],
                atol=1e-4,
            )


@pytest.mark.unit
class TestTransformer:
    @classmethod
    def setup_class(cls):
        cls.n_layers = 1
        cls.d_model = 4
        cls.d_ffn = 16
        cls.sa_n_heads = 2
        cls.kernel_size = 3
        cls.p_dropout = 0.0
        cls.p_dropout_out = 0.0
        cls.is_causal = True
        cls.max_length_causal_mask = 6

        # fmt:off
        cls.input_tensor = torch.Tensor(
            [[[ 0.7049,  0.0305, -0.8542,  0.5388],
              [-0.5265, -1.3320,  1.5451,  0.4086],
              [-2.0546,  0.5259,  0.5995, -0.4078],
              [ 0.4530, -0.3918,  2.1403, -0.2062],
              [-0.0984,  0.4855,  0.7076,  0.0431],
              [-0.4394, -0.6761,  1.7389, -0.9423]]]
        )
        # fmt:on

    def test_forward_causal_self_attn_and_no_xattn(self):
        set_seed(0)
        model = Transformer(
            n_layers=self.n_layers,
            d_model=self.d_model,
            d_ffn=self.d_ffn,
            sa_n_heads=self.sa_n_heads,
            kernel_size=self.kernel_size,
            p_dropout=self.p_dropout,
            p_dropout_out=self.p_dropout_out,
            has_xattn=False,
            is_causal=self.is_causal,
            max_length_causal_mask=self.max_length_causal_mask,
        )

        # Check model init
        assert torch.isclose(torch.mean(model.layers[0].pos_ff.proj.conv.weight), torch.tensor(0.0), atol=1e-2)
        assert torch.isclose(torch.std(model.layers[0].pos_ff.proj.conv.weight), torch.tensor(0.02), atol=1e-2)
        assert torch.isclose(torch.mean(model.layers[0].pos_ff.o_net.conv.weight), torch.tensor(0.0), atol=1e-2)
        assert torch.isclose(
            torch.std(model.layers[0].pos_ff.o_net.conv.weight), torch.tensor(0.02 / math.sqrt(2.0)), atol=1e-3
        )

        mask_tensor = torch.ones(1, self.max_length_causal_mask).bool()
        with torch.no_grad():
            output_dict = model(x=self.input_tensor, x_mask=mask_tensor)

        # fmt:off
        expected_output_tensor = {
            'output': torch.Tensor(
                [[[0.7047, 0.0305, -0.8555, 0.5402],
                  [-0.5192, -1.3324, 1.5455, 0.4148],
                  [-2.0593, 0.5290, 0.5969, -0.4101],
                  [0.4517, -0.3968, 2.1392, -0.2041],
                  [-0.1019, 0.4854, 0.7077, 0.0431],
                  [-0.4458, -0.6789, 1.7447, -0.9447]]]
            ),
            'attn_probabilities': [
                {
                    'self_attn_probabilities': [
                        torch.Tensor(
                            [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.4998, 0.5002, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.3337, 0.3333, 0.3330, 0.0000, 0.0000, 0.0000],
                               [0.2498, 0.2500, 0.2501, 0.2501, 0.0000, 0.0000],
                               [0.2002, 0.2001, 0.2000, 0.1999, 0.1999, 0.0000],
                               [0.1666, 0.1666, 0.1667, 0.1667, 0.1667, 0.1667]],
                              [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.5005, 0.4995, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.3332, 0.3331, 0.3336, 0.0000, 0.0000, 0.0000],
                               [0.2507, 0.2494, 0.2510, 0.2489, 0.0000, 0.0000],
                               [0.2002, 0.1995, 0.2006, 0.1994, 0.2003, 0.0000],
                               [0.1671, 0.1663, 0.1674, 0.1660, 0.1670, 0.1662]]]]
                        ),
                        torch.Tensor(
                            [[[[-3.4823e-04, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
                               [-7.0210e-04, -5.6984e-05, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
                               [1.2551e-03, 1.1431e-04, -5.9334e-04, float('-inf'), float('-inf'), float('-inf')],
                               [-8.1514e-04, 2.6650e-05, 5.2952e-04, 5.6903e-04, float('-inf'), float('-inf')],
                               [8.0150e-04, 1.4366e-04, -2.7793e-04, -7.0636e-04, -8.2140e-04, float('-inf')],
                               [-4.6137e-04, 6.2648e-05, 3.6768e-04, 2.8095e-04, 4.9188e-04, 3.6888e-04]],
                              [[-7.5861e-04, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
                               [8.0373e-04, -1.2745e-03, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
                               [-4.5038e-04, -6.7648e-04, 7.5978e-04, float('-inf'), float('-inf'), float('-inf')],
                               [1.5376e-03, -3.6984e-03, 3.0423e-03, -5.4870e-03, float('-inf'), float('-inf')],
                               [3.7014e-04, -2.7010e-03, 2.4310e-03, -3.2604e-03, 9.3840e-04, float('-inf')],
                               [1.3868e-03, -3.8372e-03, 3.2144e-03, -5.4860e-03, 5.0273e-04, -4.4343e-03]]]]
                        ),
                    ],
                    'cross_attn_probabilities': None,
                }
            ],
        }
        # fmt:on

        assert output_dict["output"].shape == expected_output_tensor["output"].shape
        assert torch.allclose(output_dict["output"], expected_output_tensor["output"], atol=1e-4)
        for i in range(2):
            assert torch.allclose(
                output_dict["attn_probabilities"][0]["self_attn_probabilities"][i],
                expected_output_tensor["attn_probabilities"][0]["self_attn_probabilities"][i],
                atol=1e-4,
            )
        assert output_dict["attn_probabilities"][0]["cross_attn_probabilities"] is None

    def test_forward_causal_self_attn_and_has_xattn(self):
        set_seed(0)
        model = Transformer(
            n_layers=2,
            d_model=self.d_model,
            d_ffn=self.d_ffn,
            sa_n_heads=self.sa_n_heads,
            kernel_size=self.kernel_size,
            p_dropout=self.p_dropout,
            p_dropout_out=self.p_dropout_out,
            has_xattn=True,
            xa_d_memory=4,
            xa_n_heads=2,
            is_causal=self.is_causal,
            max_length_causal_mask=self.max_length_causal_mask,
        )

        # fmt:off
        cond = [
            # shape (1, 3, 4)
            torch.Tensor(
                [[[-0.7475,  1.1461,  0.7300,  1.4471],
                  [ 1.8744, -0.1654,  1.2418, -1.6983],
                  [-0.3123,  0.2320,  0.7457,  1.9868]]]
            ),
            # shape (1, 5, 4)
            torch.Tensor(
                [[[-0.6683, -1.2178,  1.3696,  0.9941],
                  [ 0.0297, -0.1616,  0.1891,  0.0580],
                  [-1.0771,  0.2547, -1.4023,  0.0971],
                  [ 1.1132,  0.6311, -0.1449,  0.2351],
                  [ 0.8920,  2.3663,  0.2248, -0.7298]]]
            )
        ]
        # fmt:on

        cond_mask = [torch.ones(1, cond[0].shape[1]).bool(), torch.ones(1, cond[1].shape[1]).bool()]
        mask_tensor = torch.ones(1, self.max_length_causal_mask).bool()
        multi_encoder_mapping = [0, 1]
        with torch.no_grad():
            output_dict = model(
                x=self.input_tensor,
                x_mask=mask_tensor,
                cond=cond,
                cond_mask=cond_mask,
                multi_encoder_mapping=multi_encoder_mapping,
            )

        # fmt:off
        expected_output = {
            'output': torch.Tensor(
                [[[0.7043, 0.0288, -0.8547, 0.5384],
                  [-0.5283, -1.3311, 1.5429, 0.4083],
                  [-2.0560, 0.5259, 0.6020, -0.4099],
                  [0.4554, -0.3829, 2.1433, -0.2036],
                  [-0.0986, 0.4794, 0.7067, 0.0432],
                  [-0.4392, -0.6772, 1.7428, -0.9393]]]
            ),
            'attn_probabilities': [
                {
                    'self_attn_probabilities': [
                        torch.Tensor(
                            [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.4989, 0.5011, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.3331, 0.3332, 0.3336, 0.0000, 0.0000, 0.0000],
                               [0.2495, 0.2496, 0.2504, 0.2505, 0.0000, 0.0000],
                               [0.1998, 0.1994, 0.2002, 0.2001, 0.2005, 0.0000],
                               [0.1662, 0.1662, 0.1668, 0.1668, 0.1671, 0.1669]],
                              [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.5008, 0.4992, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.3334, 0.3336, 0.3331, 0.0000, 0.0000, 0.0000],
                               [0.2504, 0.2500, 0.2496, 0.2500, 0.0000, 0.0000],
                               [0.2001, 0.2002, 0.2000, 0.1999, 0.1998, 0.0000],
                               [0.1670, 0.1667, 0.1665, 0.1667, 0.1665, 0.1666]]]]
                        ),
                    ],
                    'cross_attn_probabilities': [
                        torch.Tensor(
                            [[[[0.3331, 0.3336, 0.3334],
                               [0.3335, 0.3331, 0.3334],
                               [0.3336, 0.3331, 0.3332],
                               [0.3335, 0.3332, 0.3334],
                               [0.3336, 0.3331, 0.3333],
                               [0.3335, 0.3332, 0.3333]],
                              [[0.3333, 0.3335, 0.3332],
                               [0.3334, 0.3335, 0.3331],
                               [0.3333, 0.3330, 0.3337],
                               [0.3334, 0.3335, 0.3332],
                               [0.3333, 0.3331, 0.3336],
                               [0.3334, 0.3334, 0.3333]]]]
                        )
                    ]
                },
                {
                    'self_attn_probabilities': [
                        torch.Tensor(
                            [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.5005, 0.4995, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.3336, 0.3330, 0.3334, 0.0000, 0.0000, 0.0000],
                               [0.2503, 0.2499, 0.2498, 0.2500, 0.0000, 0.0000],
                               [0.2002, 0.1999, 0.2000, 0.2000, 0.2000, 0.0000],
                               [0.1669, 0.1666, 0.1666, 0.1667, 0.1666, 0.1666]],
                              [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.5001, 0.4999, 0.0000, 0.0000, 0.0000, 0.0000],
                               [0.3329, 0.3330, 0.3340, 0.0000, 0.0000, 0.0000],
                               [0.2499, 0.2498, 0.2505, 0.2499, 0.0000, 0.0000],
                               [0.1997, 0.1997, 0.2003, 0.2000, 0.2004, 0.0000],
                               [0.1665, 0.1664, 0.1669, 0.1666, 0.1669, 0.1667]]]]
                        ),
                    ],
                    'cross_attn_probabilities': [
                        torch.Tensor(
                            [[[[0.1999, 0.1998, 0.2002, 0.1999, 0.2001],
                               [0.2000, 0.1997, 0.2004, 0.1998, 0.2002],
                               [0.2001, 0.2000, 0.2001, 0.1999, 0.2000],
                               [0.2000, 0.2001, 0.1998, 0.2001, 0.1999],
                               [0.2001, 0.2002, 0.1998, 0.2001, 0.1998],
                               [0.2000, 0.2002, 0.1998, 0.2001, 0.1999]],
                              [[0.1998, 0.1998, 0.2001, 0.2004, 0.2000],
                               [0.2003, 0.2003, 0.1998, 0.1995, 0.2001],
                               [0.2003, 0.2003, 0.1998, 0.1995, 0.2001],
                               [0.2001, 0.2001, 0.2000, 0.1998, 0.2000],
                               [0.2002, 0.2001, 0.1999, 0.1997, 0.2000],
                               [0.2002, 0.2001, 0.2000, 0.1997, 0.2000]]]]
                        ),
                    ],
                }
            ],
        }
        # fmt:on

        assert torch.allclose(output_dict["output"], expected_output["output"], atol=1e-4)
        for i in range(2):
            assert torch.allclose(
                output_dict["attn_probabilities"][i]["self_attn_probabilities"][0],
                expected_output["attn_probabilities"][i]["self_attn_probabilities"][0],
                atol=1e-4,
            )
            assert torch.allclose(
                output_dict["attn_probabilities"][i]["cross_attn_probabilities"][0],
                expected_output["attn_probabilities"][i]["cross_attn_probabilities"][0],
                atol=1e-4,
            )
