"""
Transcribe Hindi audio using fine-tuned Nemotron ASR model.
Model: BayAreaBoys/nemotron-hindi (fine-tuned from nvidia/nemotron-speech-streaming-en-0.6b)
Architecture: FastConformer Cache-Aware RNNT — expects 16kHz mono audio
"""

import os
import time
import torch
import soundfile as sf
import numpy as np
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
import nemo.collections.asr as nemo_asr

# Pull HF_TOKEN / HF_MODEL from a local .env file into the process environment.
load_dotenv()

# Hugging Face access token (needed for private repos); may be None if unset.
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL = os.getenv("HF_MODEL")  # BayAreaBoys/nemotron-hindi
# Input audio files, resolved relative to this script's directory in main().
AUDIO_FILES = ["1.wav", "2.wav"]
TARGET_SR = 16000  # Nemotron expects 16kHz mono


def resample_if_needed(filepath: str, target_sr: int = TARGET_SR) -> str:
    """Convert audio to mono at ``target_sr`` if needed.

    Reads the file, mixes multi-channel audio down to mono, and performs
    rational (polyphase) resampling when the sample rate differs from
    ``target_sr``. If any conversion occurred, the converted audio is written
    next to the input as ``<stem>_<target_sr>.wav`` and that path is returned;
    otherwise the original path is returned unchanged.

    Args:
        filepath: Path to the input audio file.
        target_sr: Desired sample rate in Hz (default 16000 for Nemotron).

    Returns:
        Path to a mono audio file at ``target_sr``.
    """
    from math import gcd

    from scipy.signal import resample_poly

    data, sr = sf.read(filepath, dtype="float32")

    # Mix down to mono -- the model expects single-channel input.
    was_multichannel = data.ndim > 1
    if was_multichannel:
        data = data.mean(axis=1)

    needs_resample = sr != target_sr
    if needs_resample:
        print(f"  Resampling {filepath}: {sr}Hz -> {target_sr}Hz")
        # Polyphase resampling with the smallest integer factors,
        # e.g. 24000 -> 16000 is up=2, down=3.
        g = gcd(sr, target_sr)
        data = resample_poly(data, target_sr // g, sr // g).astype(np.float32)

    # BUGFIX: previously a stereo file already at target_sr was mixed to mono
    # in memory but the ORIGINAL (stereo) path was returned. Persist the
    # converted audio whenever any conversion happened.
    if was_multichannel or needs_resample:
        # os.path.splitext only touches the extension; the old
        # str.replace(".wav", ...) would rewrite every ".wav" in the path.
        root, _ = os.path.splitext(filepath)
        out_path = f"{root}_{target_sr}.wav"
        sf.write(out_path, data, target_sr)
        return out_path

    return filepath


def main():
    print(f"=" * 60)
    print(f"Hindi Nemotron ASR Transcription")
    print(f"Model: {HF_MODEL}")
    print(f"=" * 60)

    # === Step 1: Download .nemo checkpoint from HuggingFace ===
    print(f"\n[1/3] Downloading model from HuggingFace: {HF_MODEL}")
    t0 = time.time()
    model_path = hf_hub_download(
        repo_id=HF_MODEL,
        filename="final_model.nemo",
        token=HF_TOKEN,
    )
    print(f"  Downloaded to: {model_path} ({time.time() - t0:.1f}s)")

    # === Step 2: Load NeMo ASR model ===
    print(f"\n[2/3] Loading NeMo ASR model...")
    t0 = time.time()
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    # WHY: NeMo 2.6.2 hardcodes use_lhotse=True in _setup_transcribe_dataloader,
    # which hits a lhotse DynamicCutSampler compat bug. Monkey-patch to disable it.
    original_setup = asr_model._setup_transcribe_dataloader.__func__

    def patched_setup(self_model, config):
        """Patch: force use_lhotse=False to use standard NeMo dataloader"""
        from omegaconf import DictConfig
        if 'manifest_filepath' in config:
            manifest_filepath = config['manifest_filepath']
            batch_size = config['batch_size']
        else:
            manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
            batch_size = min(config['batch_size'], len(config['paths2audio_files']))

        dl_config = {
            'manifest_filepath': manifest_filepath,
            'sample_rate': self_model.preprocessor._sample_rate,
            'batch_size': batch_size,
            'shuffle': False,
            'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)),
            'pin_memory': True,
            'channel_selector': config.get('channel_selector', None),
            'use_start_end_token': self_model.cfg.validation_ds.get('use_start_end_token', False),
        }
        if config.get("augmentor"):
            dl_config['augmentor'] = config.get("augmentor")

        temporary_datalayer = self_model._setup_dataloader_from_config(config=DictConfig(dl_config))
        return temporary_datalayer

    import types
    asr_model._setup_transcribe_dataloader = types.MethodType(patched_setup, asr_model)

    # WHY: CUDA graph compilation in RNNT label-looping decoder hits cu_call
    # unpacking bug with newer CUDA drivers. Disable CUDA graphs for decoding.
    from omegaconf import OmegaConf, open_dict
    with open_dict(asr_model.cfg):
        if hasattr(asr_model.cfg, 'decoding'):
            asr_model.cfg.decoding.greedy.loop_labels = False
    asr_model.change_decoding_strategy(asr_model.cfg.decoding)

    # Move to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    asr_model = asr_model.to(device)
    asr_model.eval()
    print(f"  Model loaded on {device} ({time.time() - t0:.1f}s)")
    print(f"  Model type: {type(asr_model).__name__}")

    # === Step 3: Resample audio & transcribe ===
    print(f"\n[3/3] Transcribing audio files...")
    base_dir = os.path.dirname(os.path.abspath(__file__))

    for audio_file in AUDIO_FILES:
        filepath = os.path.join(base_dir, audio_file)
        if not os.path.exists(filepath):
            print(f"\n  ⚠ File not found: {filepath}")
            continue

        # Resample 24kHz -> 16kHz
        processed_path = resample_if_needed(filepath)

        # Transcribe
        print(f"\n  Transcribing: {audio_file}")
        t0 = time.time()
        # NeMo transcribe returns list of transcriptions
        result = asr_model.transcribe([processed_path])
        elapsed = time.time() - t0

        # Handle different return types (some models return list, some return Hypothesis objects)
        if isinstance(result, list) and len(result) > 0:
            if isinstance(result[0], str):
                text = result[0]
            elif hasattr(result[0], 'text'):
                text = result[0].text
            else:
                text = str(result[0])
        else:
            text = str(result)

        # Get audio duration for RTF calculation
        info = sf.info(processed_path)
        duration = info.duration

        print(f"  {'─' * 50}")
        print(f"  File:     {audio_file}")
        print(f"  Duration: {duration:.2f}s")
        print(f"  Time:     {elapsed:.3f}s (RTF: {elapsed/duration:.3f})")
        print(f"  Text:     {text}")
        print(f"  {'─' * 50}")


# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
