"""
AI4Bharat IndicConformer 600M Multilingual: CTC + RNNT ASR on 22 Indic languages.
Uses trust_remote_code=True (custom HF ONNX-based model).
Provides CTC transcription + confidence scoring via greedy path logprobs + CER vs Gemini.

GPU FIX: The original TorchScript preprocessor (preprocessor.ts) conflicts with PyTorch
models (MMS LID, VoxLingua) sharing the same CUDA context. We replace it with a pure
PyTorch mel-spectrogram that runs on GPU alongside everything else.
Parameters extracted from the TorchScript graph: n_fft=512, hop=160, win=400, 80 mels,
preemphasis=0.97, per-channel mean/std normalization. Validated to match within 1e-6.
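
Typical usage (a sketch; the input variables are illustrative):

    asr = IndicConformerMulti(device="cuda")
    asr.load()
    results = asr.predict_batch(waveforms, lang_codes, reference_texts)
    asr.unload()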
"""
from __future__ import annotations

import logging
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F

from ..config import (
    CONFORMER_MULTI_MODEL, CONFORMER_BATCH_SIZE, CONFORMER_LANG_CODES,
)
from ..ctc_score import character_error_rate

logger = logging.getLogger(__name__)

# Mel spectrogram parameters (from TorchScript graph inspection)
_N_FFT = 512
_HOP_LENGTH = 160
_WIN_LENGTH = 400
_N_MELS = 80
_PREEMPH = 0.97
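# _LOG_EPS is 2 ** -24 (matches NeMo's log_zero_guard_value, added before the log)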
_LOG_EPS = 5.960464477539063e-08
_NORM_EPS = 1e-05
_PAD = 256  # reflect padding before STFT


class _MelPreprocessor(torch.nn.Module):
    """
    Pure PyTorch replacement for the NeMo TorchScript AudioToMelSpectrogramPreprocessor.
    Runs on CUDA without TorchScript context conflicts.
    """

    def __init__(self, mel_fb: torch.Tensor, stft_window: torch.Tensor):
        super().__init__()
        self.register_buffer("mel_fb", mel_fb)        # [n_fft//2+1, n_mels]
        self.register_buffer("window", stft_window)    # [win_length]

    def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            wav: [B, samples] raw audio at 16kHz
        Returns:
            features: [B, n_mels, T] normalized log-mel spectrogram
            lengths:  [B] number of valid frames
        """
        B = wav.shape[0]

        # Pre-emphasis: x[n] - 0.97 * x[n-1]
        padded = F.pad(wav, (1, 0))
        emphasized = wav - _PREEMPH * padded[:, :-1]

        # Reflect-pad 256 samples each side (matches NeMo)
        emphasized = F.pad(emphasized, (_PAD, _PAD), mode="reflect")

        # STFT: torch.stft accepts batched [B, L] input, and every row in a
        # single forward call shares the same length, so one call suffices
        spec = torch.stft(
            emphasized, n_fft=_N_FFT, hop_length=_HOP_LENGTH,
            win_length=_WIN_LENGTH, window=self.window,
            center=False, normalized=False, return_complex=True,
        )  # [B, freq, T]

        # Power spectrum → mel filterbank → log
        power = spec.abs().pow(2)                                      # [B, 257, T]
        mel = torch.matmul(power.transpose(1, 2), self.mel_fb)        # [B, T, 80]
        mel = mel.transpose(1, 2)                                      # [B, 80, T]
        log_mel = torch.log(mel + _LOG_EPS)

        # Per-channel standardization (mean=0, std=1 over time axis)
        mean = log_mel.mean(dim=-1, keepdim=True)
        std = log_mel.std(dim=-1, keepdim=True)
        features = (log_mel - mean) / (std + _NORM_EPS)

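        # With center=False, 256-sample reflect pad per side, and n_fft=512, the
        # frame count is (L + 2*256 - 512) // hop + 1 = L // 160 + 1 for input length L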
        lengths = torch.tensor([wav.shape[-1] // _HOP_LENGTH + 1] * B,
                               dtype=torch.long, device=wav.device)
        return features, lengths

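# Sketch of the 1e-6 equivalence check noted in the module docstring
# (illustrative only; pp_cpu, mel_fb, stft_window as built in load() below):
#
#   wav = torch.randn(1, 16000)
#   ref, _ = pp_cpu(input_signal=wav, length=torch.tensor([16000]))
#   ours, _ = _MelPreprocessor(mel_fb, stft_window)(wav)
#   assert (ref - ours).abs().max().item() < 1e-6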

class IndicConformerMulti:
    """
    IndicConformer 600M multilingual wrapper.

    Internal model uses ONNX Runtime (encoder + ctc_decoder).
    Preprocessor is a pure PyTorch mel-spectrogram on GPU (replaces TorchScript).
    """

    def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16):
        self.device = device
        self.dtype = dtype
        self.model = None
        self.preprocessor: Optional[_MelPreprocessor] = None
        self.batch_size = CONFORMER_BATCH_SIZE

    def load(self, hf_token: str = ""):
        logger.info(f"Loading IndicConformer 600M multilingual → {self.device}")
        from transformers import AutoModel

        kwargs = {"trust_remote_code": True}
        if hf_token:
            kwargs["token"] = hf_token

        self.model = AutoModel.from_pretrained(CONFORMER_MULTI_MODEL, **kwargs)

        # Extract mel filterbank and STFT window from the TorchScript preprocessor,
        # then build a pure PyTorch replacement that runs on GPU without conflicts.
        ts_path = f'{self.model.config.ts_folder}/assets/preprocessor.ts'
        pp_cpu = torch.jit.load(ts_path, map_location='cpu')

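        # Identify the two constant tensors purely by shape: [257, 80] is the
        # mel filterbank, [400] is the STFT window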
        mel_fb, stft_window = None, None
        for node in pp_cpu.graph.nodes():
            if node.kind() == 'prim::Constant':
                try:
                    val = node.output().toIValue()
                    if isinstance(val, torch.Tensor):
                        if val.shape == (_N_FFT // 2 + 1, _N_MELS):
                            mel_fb = val
                        elif val.shape == (_WIN_LENGTH,):
                            stft_window = val
                except Exception:
                    pass

        if mel_fb is None or stft_window is None:
            logger.warning("Could not extract mel filterbank from TorchScript, falling back to CPU")
            self._load_cpu_fallback()
            return

        self.preprocessor = _MelPreprocessor(mel_fb, stft_window).to(self.device).eval()
        del pp_cpu

        logger.info(
            f"IndicConformer loaded (GPU preprocessor): {len(self.model.vocab)} languages, "
            f"vocab sizes: {', '.join(f'{k}={len(v)}' for k,v in list(self.model.vocab.items())[:3])}..."
        )

    def _load_cpu_fallback(self):
        """Fallback: use TorchScript preprocessor on CPU if extraction fails."""
        import types
        pp_cpu = torch.jit.load(
            f'{self.model.config.ts_folder}/assets/preprocessor.ts',
            map_location='cpu',
        )
        onnx_encoder = self.model.models['encoder']

        def _encode_safe(self_model, wav):
            audio_signal, length = pp_cpu(
                input_signal=wav.cpu(),
                length=torch.tensor([wav.shape[-1]]),
            )
            outputs, enc_lengths = onnx_encoder.run(
                ['outputs', 'encoded_lengths'],
                {'audio_signal': audio_signal.numpy(), 'length': length.numpy()},
            )
            return outputs, enc_lengths

        self.model.encode = types.MethodType(_encode_safe, self.model)
        logger.info("IndicConformer loaded (CPU fallback preprocessor)")

    def unload(self):
        if self.model:
            del self.model
            self.model = None
        if self.preprocessor:
            del self.preprocessor
            self.preprocessor = None
        torch.cuda.empty_cache()

    @torch.inference_mode()
    def predict_batch(
        self, waveforms: list[torch.Tensor], lang_codes: list[str],
        reference_texts: Optional[list[str]] = None,
    ) -> list[dict]:
        """
        Batched CTC inference + scoring.
        Preprocesses all segments on GPU, then runs the ONNX encoder in batches
        of self.batch_size (~87 segments/s observed at batch_size=16).
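
        Each result dict carries conformer_multi_transcription,
        conformer_multi_ctc_raw (the mean greedy-path log-prob), and
        conformer_multi_ctc_normalized (1 - CER against the reference, if given).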
        """
        if self.preprocessor is None:
            return [self._infer_single_cpu(wav, lang, reference_texts[i] if reference_texts else None)
                    for i, (wav, lang) in enumerate(zip(waveforms, lang_codes))]

        # Filter to supported languages, track original indices
        items: list[tuple[int, torch.Tensor, str, str]] = []
        # One fresh dict per slot; a multiplied list would alias a single shared dict
        results: list[dict] = [self._empty_result() for _ in waveforms]

        for i, (wav, lang) in enumerate(zip(waveforms, lang_codes)):
            if lang in CONFORMER_LANG_CODES and lang in self.model.vocab:
                ref = reference_texts[i] if reference_texts else ""
                items.append((i, wav, lang, ref))

        if not items:
            return results

        # Step 1: GPU preprocess all segments
        features_list: list[torch.Tensor] = []
        for _, wav, _, _ in items:
            feat, _ = self.preprocessor(wav.unsqueeze(0).to(self.device))
            features_list.append(feat.squeeze(0))  # [80, T_i]

        # Step 2: Batch through ONNX encoder + CTC decoder
        blank_id = self.model.config.BLANK_ID
        encoder_session = self.model.models['encoder']
        ctc_session = self.model.models['ctc_decoder']

        for batch_start in range(0, len(items), self.batch_size):
            batch_end = min(batch_start + self.batch_size, len(items))
            batch_feats = features_list[batch_start:batch_end]
            batch_items = items[batch_start:batch_end]

            self._run_onnx_batch(
                batch_feats, batch_items, encoder_session, ctc_session,
                blank_id, results,
            )

        return results

    def _run_onnx_batch(self, batch_feats, batch_items, encoder_session, ctc_session,
                         blank_id, results):
        """Run ONNX encoder+CTC on a batch. On OOM, halve batch and retry recursively."""
        bs = len(batch_feats)
        max_T = max(f.shape[1] for f in batch_feats)
        padded = torch.zeros(bs, _N_MELS, max_T, device=self.device)
        frame_lengths = []
        for j, f in enumerate(batch_feats):
            padded[j, :, :f.shape[1]] = f
            frame_lengths.append(f.shape[1])

        try:
            enc_out, enc_len = encoder_session.run(
                ['outputs', 'encoded_lengths'],
                {'audio_signal': padded.cpu().numpy(),
                 'length': np.array(frame_lengths, dtype=np.int64)},
            )
            raw_logprobs = ctc_session.run(
                ['logprobs'], {'encoder_output': enc_out}
            )[0]
        except Exception as e:
            if 'allocate memory' in str(e) or 'RUNTIME_EXCEPTION' in str(e):
                if bs <= 1:
                    logger.warning(f"Conformer OOM even at batch_size=1, skipping {bs} segments")
                    return
                half = bs // 2
                logger.warning(f"Conformer OOM at batch_size={bs}, retrying with {half}")
                self._run_onnx_batch(
                    batch_feats[:half], batch_items[:half],
                    encoder_session, ctc_session, blank_id, results)
                self._run_onnx_batch(
                    batch_feats[half:], batch_items[half:],
                    encoder_session, ctc_session, blank_id, results)
                return
            raise

        for j, (orig_idx, _, lang, ref_text) in enumerate(batch_items):
            try:
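                # Restrict logits to this language's token subset, then
                # renormalize over it with log_softmax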
                lang_mask = self.model.language_masks[lang]
                masked = torch.from_numpy(
                    raw_logprobs[j:j+1, :, lang_mask]
                ).log_softmax(dim=-1)

                T = int(enc_len[j])
                lp = masked[0, :T]

                vocab = self.model.vocab[lang]
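                # Greedy CTC decode: frame-wise argmax, collapse repeats, then
                # drop blanks; '▁' is SentencePiece's word-boundary marker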
                indices = torch.argmax(lp, dim=-1)
                collapsed = torch.unique_consecutive(indices, dim=-1)
                hyp = ''.join([vocab[x.item()] for x in collapsed if x.item() != blank_id])
                transcription = hyp.replace('▁', ' ').strip()

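                # Raw confidence: mean of the per-frame max log-probs on the greedy path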
                greedy_conf = lp.max(dim=-1).values.mean().item()

                r = {
                    "conformer_multi_transcription": transcription,
                    "conformer_multi_ctc_raw": round(greedy_conf, 4),
                    "conformer_multi_ctc_normalized": None,
                }
                if ref_text and transcription:
                    cer = character_error_rate(ref_text, transcription)
                    r["conformer_multi_ctc_normalized"] = round(1.0 - cer, 4)

                results[orig_idx] = r
            except Exception as e:
                logger.warning(f"Conformer decode failed for item {j}: {e}")

    def _infer_single_cpu(
        self, waveform: torch.Tensor, lang_code: str,
        reference_text: Optional[str] = None,
    ) -> dict:
        """CPU fallback path — used when GPU preprocessor extraction fails."""
        result = self._empty_result()
        if lang_code not in CONFORMER_LANG_CODES or lang_code not in self.model.vocab:
            return result
        try:
            wav = waveform.unsqueeze(0)
            encoder_outputs, encoded_lengths = self.model.encode(wav)
            raw_logprobs = self.model.models['ctc_decoder'].run(
                ['logprobs'], {'encoder_output': encoder_outputs}
            )[0]
            lang_mask = self.model.language_masks[lang_code]
            masked = torch.from_numpy(raw_logprobs[:, :, lang_mask]).log_softmax(dim=-1)
            T = int(encoded_lengths[0]) if isinstance(encoded_lengths, np.ndarray) else int(encoded_lengths)
            lp = masked[0, :T]
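            # Same mask-renormalize and greedy CTC decode as the batched path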
            blank_id = self.model.config.BLANK_ID
            vocab = self.model.vocab[lang_code]
            indices = torch.argmax(lp, dim=-1)
            collapsed = torch.unique_consecutive(indices, dim=-1)
            hyp = ''.join([vocab[x.item()] for x in collapsed if x.item() != blank_id])
            transcription = hyp.replace('▁', ' ').strip()
            result["conformer_multi_transcription"] = transcription
            greedy_conf = lp.max(dim=-1).values.mean().item()
            result["conformer_multi_ctc_raw"] = round(greedy_conf, 4)
            if reference_text and transcription:
                cer = character_error_rate(reference_text, transcription)
                result["conformer_multi_ctc_normalized"] = round(1.0 - cer, 4)
        except Exception as e:
            logger.warning(f"IndicConformer CPU failed for lang={lang_code}: {e}")
        return result

    @staticmethod
    def _empty_result() -> dict:
        return {
            "conformer_multi_transcription": "",
            "conformer_multi_ctc_raw": None,
            "conformer_multi_ctc_normalized": None,
        }
