"""
Transcribe Hindi audio using fine-tuned Nemotron ASR model.
Model: BayAreaBoys/nemotron-hindi (fine-tuned from nvidia/nemotron-speech-streaming-en-0.6b)
Architecture: FastConformer Cache-Aware RNNT — expects 16kHz mono audio
"""

import os
import time
import torch
import torchaudio
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
import nemo.collections.asr as nemo_asr

# Load environment variables from a local .env file (provides HF_TOKEN / HF_MODEL).
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")  # HuggingFace access token (None is fine for public repos)
HF_MODEL = os.getenv("HF_MODEL")  # BayAreaBoys/nemotron-hindi
AUDIO_FILES = ["1.wav", "2.wav"]  # input clips, resolved relative to this file's directory
TARGET_SR = 16000  # Nemotron expects 16kHz mono


def resample_if_needed(filepath: str, target_sr: int = TARGET_SR) -> str:
    """Ensure an audio file is mono and at ``target_sr``; return a usable path.

    Loads the file, downmixes multi-channel audio to mono by averaging, and
    resamples when the source rate differs from ``target_sr``. If either
    transformation was applied, the processed audio is written next to the
    original with a ``_<target_sr>`` suffix and that path is returned;
    otherwise the original ``filepath`` is returned unchanged.

    Args:
        filepath: Path to the input audio file.
        target_sr: Desired sample rate in Hz (defaults to TARGET_SR, 16 kHz).

    Returns:
        Path to a mono file at ``target_sr`` (may be ``filepath`` itself).
    """
    waveform, sr = torchaudio.load(filepath)

    # Track whether we changed the audio at all; if so, we must write a new
    # file (the original code dropped the mono downmix when sr already
    # matched, handing a stereo file to a model that expects mono).
    modified = False

    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
        modified = True

    if sr != target_sr:
        print(f"  Resampling {filepath}: {sr}Hz -> {target_sr}Hz")
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
        modified = True

    if not modified:
        return filepath

    # splitext handles any extension (the old str.replace only worked for
    # ".wav" and would replace every occurrence of it in the path).
    root, ext = os.path.splitext(filepath)
    out_path = f"{root}_{target_sr}{ext or '.wav'}"
    torchaudio.save(out_path, waveform, target_sr)
    return out_path


def main():
    print(f"=" * 60)
    print(f"Hindi Nemotron ASR Transcription")
    print(f"Model: {HF_MODEL}")
    print(f"=" * 60)

    # === Step 1: Download .nemo checkpoint from HuggingFace ===
    print(f"\n[1/3] Downloading model from HuggingFace: {HF_MODEL}")
    t0 = time.time()
    model_path = hf_hub_download(
        repo_id=HF_MODEL,
        filename="final_model.nemo",
        token=HF_TOKEN,
    )
    print(f"  Downloaded to: {model_path} ({time.time() - t0:.1f}s)")

    # === Step 2: Load NeMo ASR model ===
    print(f"\n[2/3] Loading NeMo ASR model...")
    t0 = time.time()
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    # Move to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    asr_model = asr_model.to(device)
    asr_model.eval()
    print(f"  Model loaded on {device} ({time.time() - t0:.1f}s)")
    print(f"  Model type: {type(asr_model).__name__}")

    # === Step 3: Resample audio & transcribe ===
    print(f"\n[3/3] Transcribing audio files...")
    base_dir = os.path.dirname(os.path.abspath(__file__))

    for audio_file in AUDIO_FILES:
        filepath = os.path.join(base_dir, audio_file)
        if not os.path.exists(filepath):
            print(f"\n  ⚠ File not found: {filepath}")
            continue

        # Resample 24kHz -> 16kHz
        processed_path = resample_if_needed(filepath)

        # Transcribe
        print(f"\n  Transcribing: {audio_file}")
        t0 = time.time()
        # NeMo transcribe returns list of transcriptions
        result = asr_model.transcribe([processed_path])
        elapsed = time.time() - t0

        # Handle different return types (some models return list, some return Hypothesis objects)
        if isinstance(result, list) and len(result) > 0:
            if isinstance(result[0], str):
                text = result[0]
            elif hasattr(result[0], 'text'):
                text = result[0].text
            else:
                text = str(result[0])
        else:
            text = str(result)

        # Get audio duration for RTF calculation
        info = torchaudio.info(processed_path)
        duration = info.num_frames / info.sample_rate

        print(f"  {'─' * 50}")
        print(f"  File:     {audio_file}")
        print(f"  Duration: {duration:.2f}s")
        print(f"  Time:     {elapsed:.3f}s (RTF: {elapsed/duration:.3f})")
        print(f"  Text:     {text}")
        print(f"  {'─' * 50}")


# Run the pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()
