# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import pytest
import torch
from torch.utils.data import DataLoader, Dataset

from nemo.collections.asr.parts.mixins.diarization import DiarizeConfig, SpkDiarizationMixin


class DummyModel(torch.nn.Module):
    """Minimal single-feature linear model used as the base of the diarization test doubles.

    Besides the ``encoder`` layer it carries two bookkeeping attributes that the
    ``DiarizableDummy`` subclass updates from its mixin hooks.
    """

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(1, 1)

        # Bookkeeping read by tests; forward() does not touch these.
        self.execution_count = 0
        self.flag_begin = False

    def forward(self, x):
        # Linear(1, 1) maps a trailing feature dim of size 1 to size 1, e.g. [1, 1] -> [1, 1].
        return self.encoder(x)


@pytest.fixture()
def audio_files(test_data_dir):
    """
    Returns the decoded samples of two AN4 speaker utterances for testing.

    Args:
        test_data_dir: Root of the test data tree (provided by another fixture).

    Returns:
        A ``(audio1, audio2)`` tuple holding the sample arrays returned by
        ``soundfile.read(..., dtype='float32')`` (presumably 1-D float32 numpy arrays for
        mono files -- TODO confirm).
    """
    # NOTE: pytest marks placed on a fixture have no effect (newer pytest warns and will
    # error), so a test that needs downloaded data must carry
    # ``@pytest.mark.with_downloads()`` itself.
    import soundfile as sf  # local import: only needed when this fixture actually runs

    an4_clstk_dir = os.path.join(test_data_dir, "an4_speaker", "an4", "wav", "an4_clstk")
    audio_file1 = os.path.join(an4_clstk_dir, "fash", "an251-fash-b.wav")
    audio_file2 = os.path.join(an4_clstk_dir, "ffmm", "cen1-ffmm-b.wav")

    audio1, _ = sf.read(audio_file1, dtype='float32')
    audio2, _ = sf.read(audio_file2, dtype='float32')

    return audio1, audio2


class DiarizableDummy(DummyModel, SpkDiarizationMixin):
    """``DummyModel`` plugged into ``SpkDiarizationMixin`` by implementing the mixin's hooks.

    Each "audio file" is a numeric string. The dataloader turns it into a one-element
    float tensor, and the linear encoder maps that tensor to a single float, so tests can
    predict outputs exactly.
    """

    def _diarize_on_begin(self, audio, diarcfg: DiarizeConfig):
        # Let the mixin do its own setup first, then record that this hook fired.
        super()._diarize_on_begin(audio, diarcfg)
        self.flag_begin = True

    def _diarize_input_manifest_processing(self, audio_files: List[str], temp_dir: str, diarcfg: DiarizeConfig):
        # Write a placeholder JSON-lines manifest into temp_dir. The returned config below
        # does not reference this file.
        with open(os.path.join(temp_dir, 'dummy_manifest.json'), 'w', encoding='utf-8') as manifest:
            manifest.writelines(
                json.dumps({'audio_filepath': path, 'duration': 100000, 'text': ''}) + '\n' for path in audio_files
            )

        return {
            'paths2audio_files': audio_files,
            'batch_size': diarcfg.batch_size,
            'temp_dir': temp_dir,
            'session_len_sec': diarcfg.session_len_sec,
            'num_workers': diarcfg.num_workers,
        }

    def _setup_diarize_dataloader(self, config: Dict) -> DataLoader:
        class NumericStringDataset(Dataset):
            """Serves each numeric-string "audio file" as a float tensor of shape [1]."""

            def __init__(self, audio_files: List[str], config: Dict):
                self.audio_files = audio_files
                self.config = config

            def __getitem__(self, index):
                return torch.tensor([float(self.audio_files[index])]).view(1)

            def __len__(self):
                return len(self.audio_files)

        return DataLoader(
            dataset=NumericStringDataset(config['paths2audio_files'], config),
            batch_size=config['batch_size'],
            num_workers=config['num_workers'],
            pin_memory=False,
            drop_last=False,
        )

    def _diarize_forward(self, batch: Any):
        return self(batch)

    def _diarize_output_processing(self, outputs, uniq_ids, diarcfg: DiarizeConfig):
        # Track how many times this hook has been invoked.
        self.execution_count += 1

        values = [float(item.item()) for item in outputs]

        # `output_type` is optional on the config; when absent (or unrecognised) the
        # plain list of floats is returned.
        output_type = getattr(diarcfg, 'output_type', None)
        if output_type == 'dict':
            return {'output': values}
        if output_type == 'dict2':
            return [{'output': value} for value in values]
        if output_type == 'tuple':
            return tuple(values)
        return values


class DummyDataset(Dataset):
    """Map-style dataset yielding ``(samples, seq_len, targets, targets_len)`` tuples.

    Each element of ``audio_tensors`` is converted to a tensor of samples, and its size
    along the first dimension is reported as ``seq_len``. The targets are a fixed,
    single dummy token.

    Args:
        audio_tensors: Indexable collection of per-utterance sample data (anything
            ``torch.tensor`` accepts, or an existing ``torch.Tensor``).
        config: Optional config dict. It is stored but not read by this class.
    """

    def __init__(self, audio_tensors: List[Any], config: Optional[Dict] = None):
        self.audio_tensors = audio_tensors
        self.config = config

    def __getitem__(self, index):
        data = self.audio_tensors[index]
        # torch.tensor() on an existing tensor emits a UserWarning; clone().detach() is the
        # documented equivalent. Either way the returned samples never alias the stored data.
        if isinstance(data, torch.Tensor):
            samples = data.clone().detach()
        else:
            samples = torch.tensor(data)
        # Sequence length = number of samples along the first dimension.
        seq_len = torch.tensor(samples.shape[0], dtype=torch.long)

        # Dummy text tokens
        targets = torch.tensor([0], dtype=torch.long)
        targets_len = torch.tensor(1, dtype=torch.long)
        return (samples, seq_len, targets, targets_len)

    def __len__(self):
        return len(self.audio_tensors)


@pytest.fixture()
def dummy_model():
    """Provide a freshly constructed ``DiarizableDummy`` (function-scoped, so one per test)."""
    model = DiarizableDummy()
    return model


class TestSpkDiarizationMixin:
    """Unit tests for ``SpkDiarizationMixin``, driven through the ``DiarizableDummy`` model."""

    @staticmethod
    def _make_identity(model):
        """Put ``model`` in eval mode with encoder weight 1 and bias 0, so each output equals its input."""
        model = model.eval()
        model.encoder.weight.data.fill_(1.0)
        model.encoder.bias.data.fill_(0.0)
        return model

    @pytest.mark.unit
    def test_constructor_non_instance(self):
        model = DummyModel()
        assert not isinstance(model, SpkDiarizationMixin)
        assert not hasattr(model, 'diarize')

    @pytest.mark.unit
    def test_diarize(self, dummy_model):
        dummy_model = self._make_identity(dummy_model)

        audio = ['1.0', '2.0', '3.0']
        outputs = dummy_model.diarize(audio, batch_size=1)
        assert len(outputs) == 3
        assert list(outputs) == [1.0, 2.0, 3.0]

    @pytest.mark.unit
    def test_diarize_generator(self, dummy_model):
        dummy_model = self._make_identity(dummy_model)

        audio = ['1.0', '2.0', '3.0']
        diarize_config = DiarizeConfig(batch_size=1)
        generator = dummy_model.diarize_generator(audio, override_config=diarize_config)

        outputs = []
        for index, result in enumerate(generator, start=1):
            # batch_size=1 -> each yielded chunk holds exactly one processed output.
            assert len(result) == 1
            outputs.extend(result)
            assert len(outputs) == index

        assert outputs == [1.0, 2.0, 3.0]

    @pytest.mark.unit
    def test_diarize_generator_explicit_stop_check(self, dummy_model):
        dummy_model = self._make_identity(dummy_model)

        audio = ['1.0', '2.0', '3.0']
        diarize_config = DiarizeConfig(batch_size=1)
        generator = dummy_model.diarize_generator(audio, override_config=diarize_config)

        # Drive the generator manually to check it terminates via StopIteration.
        outputs = []
        index = 1
        while True:
            try:
                result = next(generator)
            except StopIteration:
                break
            assert len(result) == 1
            outputs.extend(result)
            assert len(outputs) == index
            index += 1

        assert outputs == [1.0, 2.0, 3.0]

    @pytest.mark.unit
    def test_diarize_check_flags(self, dummy_model):
        dummy_model = dummy_model.eval()
        # Precondition: the flag starts unset, so the final assert proves diarize() set it.
        assert not dummy_model.flag_begin

        audio = ['1.0', '2.0', '3.0']
        dummy_model.diarize(audio, batch_size=1)
        assert dummy_model.flag_begin

    # Method names below are kept as-is (including the 'transribe' typo) so test IDs do not change.
    @pytest.mark.unit
    def test_transribe_override_config_incorrect(self, dummy_model):
        # Not subclassing DiarizeConfig
        @dataclass
        class OverrideConfig:
            batch_size: int = 1
            output_type: str = 'dict'

        dummy_model = dummy_model.eval()

        # Use the same valid string inputs as the other tests, so the expected ValueError
        # is attributable to the config type rather than to unusual audio inputs.
        audio = ['1.0', '2.0', '3.0']
        override_cfg = OverrideConfig(batch_size=1, output_type='dict')
        with pytest.raises(ValueError):
            _ = dummy_model.diarize(audio, override_config=override_cfg)

    @pytest.mark.unit
    def test_transribe_override_config_correct(self, dummy_model):
        @dataclass
        class OverrideConfig(DiarizeConfig):
            output_type: str = 'dict'
            verbose: bool = False

        dummy_model = self._make_identity(dummy_model)

        audio = ['1.0', '2.0', '3.0']
        # 'list' is not a special-cased output_type, so the default list output is expected.
        override_cfg = OverrideConfig(batch_size=1, output_type='list')
        outputs = dummy_model.diarize(audio, override_config=override_cfg)

        assert isinstance(outputs, list)
        assert outputs == [1.0, 2.0, 3.0]
