# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import pytest
import torch
from scipy.optimize import linear_sum_assignment as scipy_linear_sum_assignment

from nemo.collections.asr.data.audio_to_label import repeat_signal
from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering
from nemo.collections.asr.parts.utils.offline_clustering import (
    SpeakerClustering,
    get_scale_interpolated_embs,
    getCosAffinityMatrix,
    getKneighborsConnections,
    split_input_data,
)
from nemo.collections.asr.parts.utils.online_clustering import (
    OnlineSpeakerClustering,
    get_closest_embeddings,
    get_merge_quantity,
    get_minimal_indices,
    merge_vectors,
    run_reducer,
    stitch_cluster_labels,
)
from nemo.collections.asr.parts.utils.optimization_utils import LinearSumAssignmentSolver
from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment as nemo_linear_sum_assignment
from nemo.collections.asr.parts.utils.speaker_utils import (
    OnlineSegmentor,
    check_ranges,
    fl2int,
    get_new_cursor_for_update,
    get_online_segments_from_slices,
    get_online_subsegments_from_buffer,
    get_speech_labels_for_update,
    get_sub_range_list,
    get_subsegments,
    get_subsegments_scriptable,
    get_target_sig,
    int2fl,
    is_overlap,
    merge_float_intervals,
    merge_int_intervals,
    tensor_to_list,
)


def check_range_values(target, source):
    bool_list = []
    for tgt, src in zip(target, source):
        for x, y in zip(src, tgt):
            bool_list.append(abs(x - y) < 1e-6)
    return all(bool_list)


def check_labels(target, source):
    bool_list = []
    for x, y in zip(target, source):
        bool_list.append(abs(x - y) < 1e-6)
    return all(bool_list)


def matrix(mat, use_tensor=True, dtype=torch.long):
    if use_tensor:
        mat = torch.Tensor(mat).to(dtype)
    else:
        mat = np.array(mat)
    return mat


def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim):
    """Generate a set of artificial orthogonal embedding vectors from random numbers"""
    gaus = torch.randn(emb_dim, emb_dim)
    _svd = torch.linalg.svd(gaus)
    orth = _svd[0] @ _svd[2]
    orth_embs = orth[:total_spks]
    # Assert orthogonality
    assert torch.abs(getCosAffinityMatrix(orth_embs) - torch.diag(torch.ones(total_spks))).sum() < 1e-4
    return orth_embs


def generate_toy_data(
    n_spks=2,
    spk_dur=3,
    emb_dim=192,
    perturb_sigma=0.0,
    ms_window=[1.5, 1.0, 0.5],
    ms_shift=[0.75, 0.5, 0.25],
    torch_seed=0,
):
    torch.manual_seed(torch_seed)
    spk_timestamps = [(spk_dur * k, spk_dur) for k in range(n_spks)]
    emb_list, seg_list = [], []
    multiscale_segment_counts = [0 for _ in range(len(ms_window))]
    ground_truth = []
    random_orthogonal_embs = generate_orthogonal_embs(n_spks, perturb_sigma, emb_dim)
    for scale_idx, (window, shift) in enumerate(zip(ms_window, ms_shift)):
        for spk_idx, (offset, dur) in enumerate(spk_timestamps):
            segments_stt_dur = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=dur)
            segments = [[x[0], x[0] + x[1]] for x in segments_stt_dur]
            emb_cent = random_orthogonal_embs[spk_idx, :]
            emb = emb_cent.tile((len(segments), 1)) + 0.1 * torch.rand(len(segments), emb_dim)
            seg_list.extend(segments)
            emb_list.append(emb)
            if emb.shape[0] == 0:
                import ipdb

                ipdb.set_trace()
            multiscale_segment_counts[scale_idx] += emb.shape[0]

            if scale_idx == len(multiscale_segment_counts) - 1:
                ground_truth.extend([spk_idx] * emb.shape[0])

    emb_tensor = torch.concat(emb_list)
    multiscale_segment_counts = torch.tensor(multiscale_segment_counts)
    segm_tensor = torch.tensor(seg_list)
    multiscale_weights = torch.ones(len(ms_window)).unsqueeze(0)
    ground_truth = torch.tensor(ground_truth)
    return emb_tensor, segm_tensor, multiscale_segment_counts, multiscale_weights, spk_timestamps, ground_truth


class TestDiarizationSequneceUtilFunctions:
    """Tests diarization and speaker-task related utils."""

    @pytest.mark.unit
    @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]])
    @pytest.mark.parametrize("target", [[0, 0, 0, 1, 1, 2]])
    @pytest.mark.parametrize("offset", [1, 10])
    def test_minimal_index_ex2(self, Y, target, offset):
        Y = torch.tensor(Y)
        target = torch.tensor(target)
        min_Y = get_minimal_indices(Y)
        assert check_labels(target, min_Y)
        min_Y = get_minimal_indices(Y + offset)
        assert check_labels(target, min_Y)

    @pytest.mark.parametrize("Y", [[4, 0, 0, 5, 4, 5], [14, 12, 12, 19, 14, 19]])
    @pytest.mark.parametrize("target", [[1, 0, 0, 2, 1, 2]])
    @pytest.mark.parametrize("offset", [1, 10])
    def test_minimal_index_ex2(self, Y, target, offset):
        Y = torch.tensor(Y)
        target = torch.tensor(target)
        min_Y = get_minimal_indices(Y)
        assert check_labels(target, min_Y)
        min_Y = get_minimal_indices(Y + offset)
        assert check_labels(target, min_Y)

    @pytest.mark.unit
    @pytest.mark.parametrize("N", [2, 4, 16, 64])
    def test_minimal_index_same(self, N):
        Y = matrix([0] * N + [1] * N + [2] * N)
        min_Y = get_minimal_indices(Y)
        target = matrix([0] * N + [1] * N + [2] * N)
        assert check_labels(target, min_Y)

    @pytest.mark.unit
    @pytest.mark.parametrize("N", [2, 4, 16, 64])
    def test_stitch_cluster_labels_label_switch(self, N):
        Y_old = matrix([0] * N)
        Y_new = matrix([0] * N) + 1
        target = matrix([0] * N)
        result = stitch_cluster_labels(Y_old, Y_new)
        assert check_labels(target, result)

    @pytest.mark.unit
    @pytest.mark.parametrize("N", [2, 4, 16, 64])
    def test_stitch_cluster_labels_label_many_to_one(self, N):
        Y_old = matrix(np.arange(N).tolist())
        Y_new = matrix([0] * N)
        target = matrix([0] * N)
        result = stitch_cluster_labels(Y_old, Y_new)
        assert check_labels(target, result)

    @pytest.mark.unit
    @pytest.mark.parametrize("N", [2, 4, 16, 64])
    def test_stitch_cluster_labels_label_one_to_many(self, N):
        Y_old = matrix(np.arange(N).tolist())
        Y_new = matrix([k for k in range(N)])
        target = matrix([k for k in range(N)])
        result = stitch_cluster_labels(Y_old, Y_new)
        assert check_labels(target, result)

    @pytest.mark.unit
    @pytest.mark.parametrize("N", [2, 4, 16, 64])
    def test_stitch_cluster_labels_one_label_replaced(self, N):
        Y_old = matrix([0] * N + [1] * N + [2] * N)
        Y_new = matrix([1] * N + [2] * N + [3] * N)
        target = matrix([0] * N + [1] * N + [2] * N)
        result = stitch_cluster_labels(Y_old, Y_new)
        assert check_labels(target, result)

    @pytest.mark.unit
    @pytest.mark.parametrize("N", [2, 4, 16, 64])
    def test_stitch_cluster_labels_confusion_error(self, N):
        Y_old = matrix([0] * N + [1] * (N - 1) + [2] * (N + 1))
        Y_new = matrix([1] * N + [2] * N + [3] * N)
        target = matrix([0] * N + [1] * N + [2] * N)
        result = stitch_cluster_labels(Y_old, Y_new)
        assert check_labels(target, result)

    @pytest.mark.unit
    @pytest.mark.parametrize("N", [2, 256])
    def test_stitch_cluster_labels_speaker_more_speakers(self, N):
        Y_old = matrix([0] * N + [1] * (N - 1) + [2] * (N + 1) + [0, 0, 0])
        Y_new = matrix([1] * N + [0] * N + [2] * N + [4, 5, 6])
        target = matrix([0] * N + [1] * N + [2] * N + [3, 4, 5])
        result = stitch_cluster_labels(Y_old, Y_new)
        assert check_labels(target, result)

    @pytest.mark.unit
    @pytest.mark.parametrize("N", [2, 256])
    def test_stitch_cluster_labels_speaker_longer_sequence(self, N):
        Y_old = matrix([0] * N + [1] * N + [2] * N + [0, 0, 0] * N)
        Y_new = matrix([1] * N + [2] * N + [0] * N + [1, 2, 3, 1, 2, 3] * N)
        target = matrix([0] * N + [1] * N + [2] * N + [0, 1, 3, 0, 1, 3] * N)
        result = stitch_cluster_labels(Y_old, Y_new)
        assert check_labels(target, result)

    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks", [2, 3, 4, 5])
    @pytest.mark.parametrize("merge_quantity", [2, 3])
    def test_embedding_merger(self, n_spks, merge_quantity):
        em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks, spk_dur=5, perturb_sigma=10)
        em_s, ts_s = split_input_data(em, ts, mc)
        target_speaker_index = 0
        pre_clus_labels = gt
        ndx = torch.where(pre_clus_labels == target_speaker_index)[0]
        pre_embs = em_s[-1]
        affinity_mat = getCosAffinityMatrix(pre_embs)
        cmat = affinity_mat[:, ndx][ndx, :]
        # Check the dimension of the selected affinity values
        assert cmat.shape[0] == cmat.shape[1] == torch.sum(pre_clus_labels == target_speaker_index).item()
        index_2d, rest_inds = get_closest_embeddings(cmat, merge_quantity)
        # Check the most closest affinity value
        assert torch.max(cmat.sum(0)) == cmat.sum(0)[index_2d[0]]
        spk_cluster_labels, emb_ndx = pre_clus_labels[ndx], pre_embs[ndx]
        merged_embs, merged_clus_labels = merge_vectors(index_2d, emb_ndx, spk_cluster_labels)
        # Check the number of merged embeddings and labels
        assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0]

    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks", [1, 8])
    @pytest.mark.parametrize("spk_dur", [0.2, 0.25, 0.5, 1, 10])
    def test_cosine_affinity_calculation(self, n_spks, spk_dur):
        em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=spk_dur)
        em_s, ts_s = split_input_data(em, ts, mc)
        affinity_mat = getCosAffinityMatrix(em_s[-1])
        # affinity_mat should not contain any nan element
        assert torch.any(torch.isnan(affinity_mat)) == False

    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks", [1, 8])
    @pytest.mark.parametrize("spk_dur", [0.2, 0.25, 0.5, 1, 10])
    def test_cosine_affinity_calculation_scale_interpol(self, n_spks, spk_dur):
        em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=spk_dur)
        em_s, ts_s = split_input_data(em, ts, mc)
        embs, _ = get_scale_interpolated_embs(mw, em_s, ts_s)
        affinity_mat = getCosAffinityMatrix(embs)
        # affinity_mat should not contain any nan element
        assert torch.any(torch.isnan(affinity_mat)) == False

    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks", [4, 5, 6])
    @pytest.mark.parametrize("target_speaker_index", [0, 1, 2])
    @pytest.mark.parametrize("merge_quantity", [2, 3])
    def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity):
        em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10)
        em_s, ts_s = split_input_data(em, ts, mc)
        merged_embs, merged_clus_labels, _ = run_reducer(
            pre_embs=em_s[-1],
            target_spk_idx=target_speaker_index,
            merge_quantity=merge_quantity,
            pre_clus_labels=gt,
        )
        assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0]

    @pytest.mark.unit
    @pytest.mark.parametrize("ntbr", [3])
    @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)])
    @pytest.mark.parametrize("mspb", [25])
    def test_merge_scheduler_2clus(self, ntbr, pcl, mspb):
        class_target_vol = get_merge_quantity(
            num_to_be_removed=ntbr,
            pre_clus_labels=pcl,
            min_count_per_cluster=mspb,
        )
        assert all(class_target_vol == torch.tensor([3, 0]))

    @pytest.mark.unit
    @pytest.mark.parametrize("ntbr", [3])
    @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)])
    @pytest.mark.parametrize("mspb", [0, 25])
    def test_merge_scheduler_3clus(self, ntbr, pcl, mspb):
        class_target_vol = get_merge_quantity(
            num_to_be_removed=ntbr,
            pre_clus_labels=pcl,
            min_count_per_cluster=mspb,
        )
        assert all(class_target_vol == torch.tensor([3, 0, 0]))

    @pytest.mark.unit
    @pytest.mark.parametrize("ntbr", [132 - 45])
    @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)])
    @pytest.mark.parametrize("mspb", [3, 10])
    def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb):
        class_target_vol = get_merge_quantity(
            num_to_be_removed=ntbr,
            pre_clus_labels=pcl,
            min_count_per_cluster=mspb,
        )
        assert all(class_target_vol == torch.tensor([18, 13, 56, 0]))

    @pytest.mark.unit
    @pytest.mark.parametrize("ntbr", [3])
    @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)])
    @pytest.mark.parametrize("mspb", [0, 2])
    def test_merge_scheduler_3clus(self, ntbr, pcl, mspb):
        class_target_vol = get_merge_quantity(
            num_to_be_removed=ntbr,
            pre_clus_labels=pcl,
            min_count_per_cluster=mspb,
        )
        assert all(class_target_vol == torch.tensor([2, 1, 0]))

    @pytest.mark.unit
    @pytest.mark.parametrize("ntbr", [2])
    @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)])
    @pytest.mark.parametrize("mspb", [2])
    def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb):
        class_target_vol = get_merge_quantity(
            num_to_be_removed=ntbr,
            pre_clus_labels=pcl,
            min_count_per_cluster=mspb,
        )
        assert all(class_target_vol == torch.tensor([2, 0, 0, 0]))


class TestClassExport:
    @pytest.mark.unit
    def test_online_segmentor_class_export(self):
        _OnlineSegmentor = torch.jit.script(OnlineSegmentor)
        online_segmentor = _OnlineSegmentor(sample_rate=16000)
        assert isinstance(online_segmentor, OnlineSegmentor)

    @pytest.mark.unit
    def test_online_segmentor_instance_export(self):
        online_segmentor = OnlineSegmentor(sample_rate=16000)
        online_segmentor = torch.jit.script(online_segmentor)
        isinstance(online_segmentor, torch.jit._script.RecursiveScriptClass)

    @pytest.mark.unit
    def test_online_speaker_clustering_instance_export(self):
        online_clus = OnlineSpeakerClustering(
            max_num_speakers=8,
            max_rp_threshold=0.15,
            sparse_search_volume=30,
            history_buffer_size=150,
            current_buffer_size=150,
            cuda=True,
        )
        online_clus = torch.jit.script(online_clus)
        isinstance(online_clus, torch.jit._script.RecursiveScriptClass)

    @pytest.mark.unit
    def test_online_speaker_clustering_instance_export(self):
        offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=True)
        offline_speaker_clustering = torch.jit.script(offline_speaker_clustering)
        isinstance(offline_speaker_clustering, torch.jit._script.RecursiveScriptClass)


class TestGetSubsegments:
    @pytest.mark.unit
    @pytest.mark.parametrize(
        "offset, window, shift, duration, min_subsegment_duration, decimals, use_asr_style_frame_count, sample_rate, feat_per_sec, expected",
        [
            (12.05, 1.5, 0.75, 2.4, 0.01, 2, False, 16000, 100, [[12.05, 1.5], [12.8, 1.5], [13.55, 0.9]]),
            (0, 1.0, 0.5, 0.4, 0.01, 2, False, 16000, 100, [[0, 0.4]]),
            (0, 2.0, 1.0, 1.5, 0.5, 2, False, 16000, 100, [[0, 1.5]]),
            (
                10,
                1.5,
                0.75,
                4.5,
                0.5,
                2,
                False,
                16000,
                100,
                [[10, 1.5], [10.75, 1.5], [11.5, 1.5], [12.25, 1.5], [13.0, 1.5]],
            ),
            (0, 1.5, 0.5, 0.3, 0.01, 2, True, 16000, 100, [[0, 0.3]]),
        ],
    )
    def test_get_subsegments(
        self,
        offset,
        window,
        shift,
        duration,
        min_subsegment_duration,
        decimals,
        use_asr_style_frame_count,
        sample_rate,
        feat_per_sec,
        expected,
    ):

        for is_scriptable in [True, False]:
            if is_scriptable:
                result = get_subsegments_scriptable(
                    offset=offset,
                    window=window,
                    shift=shift,
                    duration=duration,
                )
            else:
                result = get_subsegments(
                    offset=offset,
                    window=window,
                    shift=shift,
                    duration=duration,
                    min_subsegment_duration=min_subsegment_duration,
                    decimals=decimals,
                    use_asr_style_frame_count=use_asr_style_frame_count,
                    sample_rate=sample_rate,
                    feat_per_sec=feat_per_sec,
                )
            result_round = []
            for subsegment in result:
                result_round.append([round(x, decimals) for x in subsegment])
            assert result_round == expected

    @pytest.mark.unit
    def test_min_subsegment_duration_filtering(self):
        result = get_subsegments(
            offset=0,
            window=1.5,
            shift=0.5,
            duration=3,
            min_subsegment_duration=2.0,
            decimals=2,
            use_asr_style_frame_count=False,
        )
        expected = []  # Only subsegments meeting the duration filter should remain
        assert result == expected

    @pytest.mark.unit
    def test_zero_duration(self):
        result = get_subsegments(
            offset=0,
            window=1.0,
            shift=0.5,
            duration=0,
            min_subsegment_duration=0.01,
            decimals=2,
            use_asr_style_frame_count=False,
        )
        assert result == []

    @pytest.mark.unit
    def test_edge_case_short_slice(self):
        result = get_subsegments(
            offset=0,
            window=0.5,
            shift=0.25,  # Shift larger than duration
            duration=0.25,
            min_subsegment_duration=0.01,
            decimals=2,
            use_asr_style_frame_count=False,
        )
        assert result == [[0.0, 0.25]]


class TestDiarizationSegmentationUtils:
    """
    Test segmentation util functions
    """

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "intervals",
        [
            [[1, 4], [2, 6], [8, 10], [15, 18]],
            [[8, 10], [15, 18], [2, 6], [1, 3]],
            [[8, 10], [15, 18], [2, 6], [1, 3], [3, 5]],
            [[8, 10], [8, 8], [15, 18], [2, 6], [1, 6], [2, 4]],
        ],
    )
    @pytest.mark.parametrize("target", [[[1, 6], [8, 10], [15, 18]]])
    def test_merge_int_intervals_ex1(self, intervals, target):
        merged = merge_int_intervals(intervals)
        assert check_range_values(target, merged)

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "intervals",
        [
            [[6, 8], [0, 9], [2, 4], [4, 7]],
            [[0, 9], [6, 8], [4, 7], [2, 4]],
            [[0, 4], [0, 0], [4, 9], [2, 4]],
            [[6, 8], [2, 8], [0, 3], [3, 4], [4, 5], [5, 9]],
        ],
    )
    @pytest.mark.parametrize("target", [[[0, 9]]])
    def test_merge_int_intervals_ex2(self, intervals, target):
        merged = merge_int_intervals(intervals)
        assert check_range_values(target, merged)

    @pytest.mark.unit
    @pytest.mark.parametrize("intervals", [[[0, 1], [1, 9]], [[0, 0], [0, 9]], [[0, 9], [0, 9]]])
    @pytest.mark.parametrize("target", [[[0, 9]]])
    def test_merge_int_intervals_edge_test(self, intervals, target):
        merged = merge_int_intervals(intervals)
        assert check_range_values(target, merged)

    @pytest.mark.unit
    @pytest.mark.parametrize("rangeA", [[1.0, 2.0]])
    @pytest.mark.parametrize("rangeB", [[0.5, 1.5], [0.9999, 1.0001]])
    def test_is_overlap_true(self, rangeA, rangeB):
        assert is_overlap(rangeA, rangeB)

    @pytest.mark.unit
    @pytest.mark.parametrize("rangeA", [[1.0, 2.0]])
    @pytest.mark.parametrize("rangeB", [[2.0, 2.5], [-1.0, 1.00]])
    def test_is_overlap_false(self, rangeA, rangeB):
        assert not is_overlap(rangeA, rangeB)

    @pytest.mark.unit
    @pytest.mark.parametrize("x", [1.0, 2.3456])
    @pytest.mark.parametrize("decimals", [1, 2, 3, 4])
    def test_fl2int(self, x, decimals):
        assert fl2int(x, decimals) == round(x * 10**decimals, 0)

    @pytest.mark.unit
    @pytest.mark.parametrize("x", [1234])
    @pytest.mark.parametrize(
        "decimals",
        [
            1,
            2,
            3,
            4,
        ],
    )
    def test_int2fl(self, x, decimals):
        assert abs(int2fl(x, decimals) - round(x / (10**decimals), decimals)) < (10 ** -(decimals + 1))

    @pytest.mark.unit
    def test_merge_float_intervals_edge_margin_test(self):
        intervals = [[0.0, 1.0], [1.0, 2.0]]

        target_0 = [[0.0, 2.0]]
        merged_0 = merge_float_intervals(intervals, margin=0)
        assert check_range_values(target_0, merged_0)

        target_1 = [[0.0, 1.0], [1.0, 2.0]]
        merged_1 = merge_float_intervals(intervals, margin=1)
        assert check_range_values(target_1, merged_1)

        target_2 = [[0.0, 1.0], [1.0, 2.0]]
        merged_2 = merge_float_intervals(intervals, margin=2)
        assert check_range_values(target_2, merged_2)

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "intervals",
        [
            [[0.25, 1.7], [1.5, 3.0], [2.8, 5.0], [5.5, 10.0]],
            [[0.25, 5.0], [5.5, 10.0], [1.5, 3.5]],
            [[5.5, 8.05], [8.0, 10.0], [0.25, 5.0]],
            [[0.25, 3.0], [1.5, 3.0], [5.5, 10.0], [2.8, 5.0]],
            [[0.25, 1.7], [1.5, 3.0], [2.8, 5.0], [5.5, 10.0]],
        ],
    )
    @pytest.mark.parametrize("target", [[[0.25, 5.0], [5.5, 10.0]]])
    def test_merge_float_overlaps(self, intervals, target):
        merged = merge_float_intervals(intervals)
        assert check_range_values(target, merged)

    @pytest.mark.unit
    def test_get_speech_labels_for_update(self):
        frame_start = 3.0
        buffer_end = 6.0
        cumulative_speech_labels = torch.tensor([[0.0000, 3.7600]])
        vad_timestamps = torch.tensor([[0.9600, 4.8400]])
        cursor_for_old_segments = 1.0
        speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update(
            frame_start,
            buffer_end,
            cumulative_speech_labels,
            vad_timestamps,
            cursor_for_old_segments,
        )
        assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8
        assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8

        # Check if the ranges are containing faulty values
        assert check_ranges(speech_labels_for_update)
        assert check_ranges(cumulative_speech_labels)

    @pytest.mark.unit
    def test_get_online_subsegments_from_buffer(self):
        torch.manual_seed(0)
        sample_rate = 16000
        speech_labels_for_update = torch.Tensor([[0.0000, 3.7600]])
        audio_buffer = torch.randn(5 * sample_rate)
        segment_indexes = []
        window = 2.0
        shift = 1.0
        slice_length = int(window * sample_rate)
        range_target = [[0.0, 2.0], [1.0, 3.0], [2.0, 3.76]]
        sigs_list, sig_rangel_list, sig_indexes = get_online_subsegments_from_buffer(
            buffer_start=0.0,
            buffer_end=5.0,
            sample_rate=sample_rate,
            speech_labels_for_update=speech_labels_for_update,
            audio_buffer=audio_buffer,
            segment_indexes=segment_indexes,
            window=window,
            shift=shift,
        )
        assert check_range_values(target=range_target, source=sig_rangel_list)
        for k, rg in enumerate(sig_rangel_list):
            signal = get_target_sig(audio_buffer, rg[0], rg[1], slice_length, sample_rate)
            if len(signal) < int(window * sample_rate):
                signal = repeat_signal(signal, len(signal), slice_length)
            assert len(signal) == int(slice_length), "Length mismatch"
            assert (np.abs(signal - sigs_list[k])).sum() < 1e-8, "Audio stream mismatch"
        assert (torch.tensor(sig_indexes) - torch.arange(len(range_target))).sum() < 1e-8, "Segment index mismatch"

    @pytest.mark.unit
    @pytest.mark.parametrize("frame_start", [3.0])
    @pytest.mark.parametrize("segment_range_ts", [[[0.0, 2.0]]])
    @pytest.mark.parametrize("gt_cursor_for_old_segments", [3.0])
    @pytest.mark.parametrize("gt_cursor_index", [1])
    def test_get_new_cursor_for_update_mulsegs_ex1(
        self, frame_start, segment_range_ts, gt_cursor_for_old_segments, gt_cursor_index
    ):
        cursor_for_old_segments, cursor_index = get_new_cursor_for_update(frame_start, segment_range_ts)
        assert cursor_for_old_segments == gt_cursor_for_old_segments
        assert cursor_index == gt_cursor_index

    @pytest.mark.unit
    @pytest.mark.parametrize("target_range", [[1.0, 4.0]])
    @pytest.mark.parametrize(
        "source_range_list", [[[2.0, 3.0], [3.0, 4.0]], [[0.0, 2.0], [3.0, 5.0]], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]
    )
    def get_sub_range_list(self, target_range, source_range_list):
        sub_range_list = get_sub_range_list(target_range, source_range_list)
        assert sub_range_list == [[2.0, 3.0], [3.0, 4.0]]

    @pytest.mark.unit
    @pytest.mark.parametrize("source_range_list", [[[0.0, 2.0]], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])
    def test_tensor_to_list(self, source_range_list):
        a_range_tensor = torch.tensor(source_range_list)
        converted_list = tensor_to_list(a_range_tensor)
        assert source_range_list == converted_list

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate",
        [
            (0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000),
            (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),
        ],
    )
    def test_get_online_segments_from_slices(
        self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate
    ):
        sig = torch.randn(int(sample_rate * buffer_end))
        ind_offset, sigs_list, sig_rangel_list, sig_indexes = get_online_segments_from_slices(
            sig, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate
        )
        assert ind_offset == 2
        assert len(sigs_list) == 2
        assert len(sig_rangel_list) == 2
        assert len(sig_indexes) == 2


class TestClusteringUtilFunctions:
    @pytest.mark.parametrize("p_value", [1, 5, 9])
    @pytest.mark.parametrize("N", [9, 20])
    @pytest.mark.parametrize("mask_method", ['binary', 'sigmoid', 'drop'])
    def test_get_k_neighbors_connections(self, p_value: int, N: int, mask_method: str, seed=0):
        torch.manual_seed(seed)
        random_mat = torch.rand(N, N)
        affinity_mat = 0.5 * (random_mat + random_mat.T)
        affinity_mat = affinity_mat / torch.max(affinity_mat)
        binarized_affinity_mat = getKneighborsConnections(affinity_mat, p_value, mask_method)
        if mask_method == 'binary':
            assert all(binarized_affinity_mat.sum(dim=0) == float(p_value))
        elif mask_method == 'sigmoid':
            assert all(binarized_affinity_mat.sum(dim=0) <= float(p_value))
        elif mask_method == 'drop':
            assert all(binarized_affinity_mat.sum(dim=0) <= float(p_value))

    @pytest.mark.unit
    @pytest.mark.parametrize("Y_aggr", [torch.tensor([0, 1, 0, 1])])
    @pytest.mark.parametrize("chunk_cluster_count, embeddings_per_chunk", [(2, 50)])
    @pytest.mark.parametrize("window_range_list", [[[0, 1], [1, 2], [2, 3], [3, 4]]])
    @pytest.mark.parametrize(
        "absolute_merge_mapping",
        [[[torch.tensor([]), torch.tensor([0, 2])], [torch.tensor([]), torch.tensor([1, 3])]]],
    )
    @pytest.mark.parametrize("org_len", [4])
    def test_unpack_labels(
        self, Y_aggr, window_range_list, absolute_merge_mapping, chunk_cluster_count, embeddings_per_chunk, org_len
    ):
        expected_result = Y_aggr
        longform_speaker_clustering = LongFormSpeakerClustering(cuda=False)
        output = longform_speaker_clustering.unpack_labels(Y_aggr, window_range_list, absolute_merge_mapping, org_len)
        assert torch.equal(output, expected_result)


class TestSpeakerClustering:
    """
    Test speaker clustering module
    """

    @pytest.mark.unit
    @pytest.mark.parametrize("cuda", [True, False])
    def test_offline_clus_script_save_load(self, cuda):
        exported_filename = 'speaker_clustering_script.pt'
        speaker_clustering_python = SpeakerClustering(maj_vote_spk_count=False, cuda=cuda)
        speaker_clustering_scripted_source = torch.jit.script(speaker_clustering_python)
        torch.jit.save(speaker_clustering_scripted_source, exported_filename)
        assert os.path.exists(exported_filename)
        os.remove(exported_filename)
        assert not os.path.exists(exported_filename)

    @pytest.mark.unit
    @pytest.mark.parametrize("cuda", [True, False])
    def test_online_clus_script_save_load(self, cuda):
        exported_filename = 'speaker_clustering_script.pt'
        speaker_clustering_python = OnlineSpeakerClustering(
            max_num_speakers=8,
            max_rp_threshold=0.15,
            sparse_search_volume=30,
            history_buffer_size=150,
            current_buffer_size=150,
            cuda=cuda,
        )
        speaker_clustering_scripted_source = torch.jit.script(speaker_clustering_python)
        torch.jit.save(speaker_clustering_scripted_source, exported_filename)
        assert os.path.exists(exported_filename)
        os.remove(exported_filename)
        assert not os.path.exists(exported_filename)

    @pytest.mark.run_only_on('GPU')
    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks", [1, 2, 3, 4, 5, 6, 7])
    @pytest.mark.parametrize("total_sec, SSV, perturb_sigma, seed", [(30, 10, 0.1, 0)])
    @pytest.mark.parametrize("jit_script", [False, True])
    def test_offline_speaker_clustering(self, n_spks, total_sec, SSV, perturb_sigma, seed, jit_script, cuda=True):
        spk_dur = total_sec / n_spks
        em, ts, mc, mw, spk_ts, gt = generate_toy_data(
            n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=perturb_sigma, torch_seed=seed
        )
        offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, cuda=cuda)
        assert isinstance(offline_speaker_clustering, SpeakerClustering)
        if jit_script:
            offline_speaker_clustering = torch.jit.script(offline_speaker_clustering)

        Y_out = offline_speaker_clustering.forward_infer(
            embeddings_in_scales=em,
            timestamps_in_scales=ts,
            multiscale_segment_counts=mc,
            multiscale_weights=mw,
            oracle_num_speakers=-1,
            max_num_speakers=8,
            enhanced_count_thres=40,
            sparse_search_volume=SSV,
            max_rp_threshold=0.15,
            fixed_thres=-1.0,
        )
        permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out)
        permuted_Y = permuted_Y.to(gt.device)
        # mc[-1] is the number of base scale segments
        assert len(set(permuted_Y.tolist())) == n_spks
        assert Y_out.shape[0] == mc[-1]
        assert all(permuted_Y == gt)

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks", [1, 2, 3, 4, 5, 6, 7])
    @pytest.mark.parametrize("total_sec, SSV, perturb_sigma, seed", [(30, 10, 0.1, 0)])
    @pytest.mark.parametrize("jit_script", [False, True])
    def test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_sigma, seed, jit_script, cuda=False):
        self.test_offline_speaker_clustering(n_spks, total_sec, SSV, perturb_sigma, seed, jit_script, cuda=cuda)

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks", [1])
    @pytest.mark.parametrize("spk_dur", [0.25, 0.5, 0.75, 1, 1.5, 2])
    @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)])
    @pytest.mark.parametrize("seed", [0])
    def test_offline_speaker_clustering_very_short_cpu(
        self,
        n_spks,
        spk_dur,
        SSV,
        enhanced_count_thres,
        min_samples_for_nmesc,
        seed,
    ):
        em, ts, mc, mw, spk_ts, gt = generate_toy_data(
            n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed
        )
        offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=False)
        assert isinstance(offline_speaker_clustering, SpeakerClustering)
        Y_out = offline_speaker_clustering.forward_infer(
            embeddings_in_scales=em,
            timestamps_in_scales=ts,
            multiscale_segment_counts=mc,
            multiscale_weights=mw,
            oracle_num_speakers=-1,
            max_num_speakers=8,
            enhanced_count_thres=enhanced_count_thres,
            sparse_search_volume=SSV,
            max_rp_threshold=0.15,
            fixed_thres=-1.0,
        )
        permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out)
        permuted_Y = permuted_Y.to(gt.device)
        # mc[-1] is the number of base scale segments
        assert len(set(permuted_Y.tolist())) == n_spks
        assert Y_out.shape[0] == mc[-1]
        assert all(permuted_Y == gt)

    @pytest.mark.run_only_on('GPU')
    @pytest.mark.unit
    @pytest.mark.parametrize("spk_dur", [0.25, 0.5, 0.75, 1, 2, 4])
    @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)])
    @pytest.mark.parametrize("seed", [0])
    def test_offline_speaker_clustering_very_short_gpu(
        self,
        n_spks,
        spk_dur,
        SSV,
        enhanced_count_thres,
        min_samples_for_nmesc,
        seed,
    ):
        em, ts, mc, mw, spk_ts, gt = generate_toy_data(
            n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed
        )
        offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=True)
        assert isinstance(offline_speaker_clustering, SpeakerClustering)
        Y_out = offline_speaker_clustering.forward_infer(
            embeddings_in_scales=em,
            timestamps_in_scales=ts,
            multiscale_segment_counts=mc,
            multiscale_weights=mw,
            oracle_num_speakers=-1,
            max_num_speakers=8,
            enhanced_count_thres=enhanced_count_thres,
            sparse_search_volume=SSV,
            max_rp_threshold=0.15,
            fixed_thres=-1.0,
        )
        permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out)
        permuted_Y = permuted_Y.to(gt.device)
        # mc[-1] is the number of base scale segments
        assert Y_out.shape[0] == mc[-1]
        assert all(permuted_Y == gt)

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(2, 5, 40, 6)])
    @pytest.mark.parametrize("spk_dur, chunk_cluster_count, embeddings_per_chunk", [(120, 4, 50), (240, 4, 100)])
    @pytest.mark.parametrize("seed", [0])
    @pytest.mark.parametrize("jit_script", [False, True])
    def test_longform_speaker_clustering_cpu(
        self,
        n_spks,
        spk_dur,
        SSV,
        enhanced_count_thres,
        min_samples_for_nmesc,
        chunk_cluster_count,
        embeddings_per_chunk,
        jit_script,
        seed,
    ):
        em, ts, mc, mw, spk_ts, gt = generate_toy_data(
            n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed
        )
        longform_speaker_clustering = LongFormSpeakerClustering(cuda=False)
        if jit_script:
            longform_speaker_clustering = torch.jit.script(longform_speaker_clustering)
        else:
            assert isinstance(longform_speaker_clustering, LongFormSpeakerClustering)
        Y_out = longform_speaker_clustering.forward_infer(
            embeddings_in_scales=em,
            timestamps_in_scales=ts,
            multiscale_segment_counts=mc,
            multiscale_weights=mw,
            oracle_num_speakers=-1,
            max_num_speakers=n_spks,
            enhanced_count_thres=enhanced_count_thres,
            sparse_search_volume=SSV,
            max_rp_threshold=0.15,
            fixed_thres=-1.0,
            chunk_cluster_count=chunk_cluster_count,
            embeddings_per_chunk=embeddings_per_chunk,
        )
        permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out)
        permuted_Y = permuted_Y.to(gt.device)

        # mc[-1] is the number of base scale segments
        assert Y_out.shape[0] == mc[-1]
        assert all(permuted_Y == gt)

    @pytest.mark.run_only_on('GPU')
    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(2, 5, 40, 6)])
    @pytest.mark.parametrize("spk_dur, chunk_cluster_count, embeddings_per_chunk", [(120, 4, 50), (240, 4, 100)])
    @pytest.mark.parametrize("seed", [0])
    @pytest.mark.parametrize("jit_script", [False, True])
    def test_longform_speaker_clustering_gpu(
        self,
        n_spks,
        spk_dur,
        SSV,
        enhanced_count_thres,
        min_samples_for_nmesc,
        chunk_cluster_count,
        embeddings_per_chunk,
        jit_script,
        seed,
    ):
        em, ts, mc, mw, spk_ts, gt = generate_toy_data(
            n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed
        )
        longform_speaker_clustering = LongFormSpeakerClustering(cuda=True)

        if jit_script:
            longform_speaker_clustering = torch.jit.script(longform_speaker_clustering)
        else:
            assert isinstance(longform_speaker_clustering, LongFormSpeakerClustering)

        Y_out = longform_speaker_clustering.forward_infer(
            embeddings_in_scales=em,
            timestamps_in_scales=ts,
            multiscale_segment_counts=mc,
            multiscale_weights=mw,
            oracle_num_speakers=-1,
            max_num_speakers=n_spks,
            enhanced_count_thres=enhanced_count_thres,
            sparse_search_volume=SSV,
            max_rp_threshold=0.15,
            fixed_thres=-1.0,
            chunk_cluster_count=chunk_cluster_count,
            embeddings_per_chunk=embeddings_per_chunk,
        )
        permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out)
        permuted_Y = permuted_Y.to(gt.device)

        # mc[-1] is the number of base scale segments
        assert Y_out.shape[0] == mc[-1]
        assert all(permuted_Y == gt)

    @pytest.mark.run_only_on('GPU')
    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks", [1, 2, 3])
    @pytest.mark.parametrize("total_sec, buffer_size, sigma", [(30, 30, 0.1)])
    @pytest.mark.parametrize("seed", [0])
    @pytest.mark.parametrize("jit_script", [False, True])
    def test_online_speaker_clustering(self, n_spks, total_sec, buffer_size, sigma, seed, jit_script, cuda=True):
        step_per_frame = 2
        spk_dur = total_sec / n_spks
        em, ts, mc, _, _, gt = generate_toy_data(n_spks, spk_dur=spk_dur, perturb_sigma=sigma, torch_seed=seed)
        em_s, ts_s = split_input_data(em, ts, mc)

        emb_gen = em_s[-1]
        segment_indexes = ts_s[-1]
        if cuda:
            device = torch.cuda.current_device()
            emb_gen, segment_indexes = emb_gen.to(device), segment_indexes.to(device)

        history_buffer_size = buffer_size
        current_buffer_size = buffer_size

        online_clus = OnlineSpeakerClustering(
            max_num_speakers=8,
            max_rp_threshold=0.15,
            sparse_search_volume=30,
            history_buffer_size=history_buffer_size,
            current_buffer_size=current_buffer_size,
            cuda=cuda,
        )
        if jit_script:
            online_clus = torch.jit.script(online_clus)

        n_frames = int(emb_gen.shape[0] / step_per_frame)
        evaluation_list = []

        # Simulate online speaker clustering
        for frame_index in range(n_frames):
            curr_emb = emb_gen[0 : (frame_index + 1) * step_per_frame]
            base_segment_indexes = torch.arange(curr_emb.shape[0]).to(curr_emb.device)
            # Check history_buffer_size and history labels
            assert (
                online_clus.history_embedding_buffer_emb.shape[0] <= history_buffer_size
            ), "History buffer size error"
            assert (
                online_clus.history_embedding_buffer_emb.shape[0]
                == online_clus.history_embedding_buffer_label.shape[0]
            )

            # Call clustering function
            merged_clus_labels = online_clus.forward_infer(
                curr_emb=curr_emb, base_segment_indexes=base_segment_indexes, frame_index=frame_index, cuda=cuda
            )

            # Resolve permutations
            assert len(merged_clus_labels) == (frame_index + 1) * step_per_frame
            # Resolve permutation issue by using stitch_cluster_labels function
            merged_clus_labels = merged_clus_labels.cpu()
            merged_clus_labels = stitch_cluster_labels(Y_old=gt[: len(merged_clus_labels)], Y_new=merged_clus_labels)
            evaluation_list.extend(list(merged_clus_labels == gt[: len(merged_clus_labels)]))

        assert online_clus.is_online
        cumul_label_acc = sum(evaluation_list) / len(evaluation_list)
        assert cumul_label_acc > 0.9

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    @pytest.mark.parametrize("n_spks, total_sec, buffer_size, sigma, seed", [(3, 30, 30, 0.1, 0)])
    @pytest.mark.parametrize("jit_script", [False, True])
    def test_online_speaker_clustering_cpu(self, n_spks, total_sec, buffer_size, sigma, seed, jit_script, cuda=False):
        self.test_online_speaker_clustering(n_spks, total_sec, buffer_size, sigma, seed, jit_script, cuda)


class TestLinearSumAssignmentAlgorithm:
    @pytest.mark.unit
    def test_lsa_solver_export_test(self):
        cost_matrix = torch.randint(0, 10, (3, 3))
        solver = LinearSumAssignmentSolver(cost_matrix)
        solver = torch.jit.script(solver)
        assert isinstance(solver, torch.jit._script.RecursiveScriptClass)

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "cost_matrix",
        [torch.tensor([[7, 6, 2, 9, 2], [6, 2, 1, 3, 9], [5, 6, 8, 9, 5], [6, 8, 5, 8, 6], [9, 5, 6, 4, 7]])],
    )
    def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix):
        """
        Test the linear sum assignment algorithm with a cost matrix

        Compare with the scipy implementation and make sure the final cost is the same.
        NOTE: There could be multiple solutions with the same cost in linear sum assignment problem.
              This test only checks if the cost is the same.
        """
        row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix)
        row_ind_sc, col_ind_sc = scipy_linear_sum_assignment(cost_matrix.cpu().numpy())
        cost_nm = sum(cost_matrix[row_ind_nm, col_ind_nm])
        cost_sc = sum(cost_matrix[row_ind_sc, col_ind_sc])
        assert cost_nm == cost_sc

    @pytest.mark.unit
    @pytest.mark.parametrize("seed", [0, 1])
    @pytest.mark.parametrize("mat_size", [1, 2, 4, 8])
    def test_linear_sum_assignment_algorithm_random_matrix(self, seed, mat_size):
        torch.manual_seed(seed)
        cost_matrix = torch.randint(0, 10, (mat_size, mat_size))
        self.test_linear_sum_assignment_algorithm_cost_matrix(cost_matrix)
