"""
AI4Bharat IndicConformer 600M Multilingual: CTC + RNNT ASR on 22 Indic languages.
Uses trust_remote_code=True (custom HF ONNX-based model).
Provides CTC transcription + confidence scoring via greedy path logprobs + CER vs Gemini.
"""
from __future__ import annotations

import logging
from typing import Optional

import numpy as np
import torch

from ..config import (
    CONFORMER_MULTI_MODEL, CONFORMER_BATCH_SIZE, CONFORMER_LANG_CODES,
)
from ..ctc_score import character_error_rate

logger = logging.getLogger(__name__)


class IndicConformerMulti:
    """
    IndicConformer 600M multilingual wrapper.

    Internal model uses ONNX Runtime (encoder + ctc_decoder).
    We extract greedy CTC transcription + greedy path confidence for scoring.
    CTC scoring against Gemini text uses CER (Character Error Rate).

    Why CER and not direct CTC log-likelihood:
      The model's vocabulary is SentencePiece subwords. Without the SPM model,
      we can't tokenize Gemini's text into the correct subword IDs for CTC scoring.
      CER gives the same signal: low CER = conformer agrees with Gemini = high quality.

    Lifecycle: construct → load() → predict_batch() → unload().
    """

    def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16):
        # NOTE(review): `dtype` is stored but never applied — inference runs
        # through ONNX Runtime with its own precision. Kept for interface parity.
        self.device = device
        self.dtype = dtype
        self.model = None  # populated by load(); None means "not loaded"
        self.batch_size = CONFORMER_BATCH_SIZE

    def load(self, hf_token: str = ""):
        """Load the HF model and patch its encode() to preprocess on CPU.

        Args:
            hf_token: optional Hugging Face access token (for gated repos).
                Empty string → use anonymous/default credentials.
        """
        logger.info(f"Loading IndicConformer 600M multilingual → {self.device}")
        from transformers import AutoModel

        kwargs = {"trust_remote_code": True}
        if hf_token:
            kwargs["token"] = hf_token

        self.model = AutoModel.from_pretrained(CONFORMER_MULTI_MODEL, **kwargs)

        # Patch encode() to avoid TorchScript-CUDA conflict with PyTorch models.
        # The TorchScript preprocessor fights with MMS/VoxLingua for the CUDA context.
        # Fix: reload preprocessor on CPU, run preprocessing on CPU, then ONNX on GPU.
        import types
        preprocessor_cpu = torch.jit.load(
            f'{self.model.config.ts_folder}/assets/preprocessor.ts',
            map_location='cpu',
        )
        onnx_encoder = self.model.models['encoder']

        def _encode_safe(self_model, wav):
            # wav: [1, samples] tensor. Feature extraction happens on CPU;
            # the encoder itself runs under ONNX Runtime.
            audio_signal, length = preprocessor_cpu(
                input_signal=wav.cpu(),
                length=torch.tensor([wav.shape[-1]]),
            )
            outputs, enc_lengths = onnx_encoder.run(
                ['outputs', 'encoded_lengths'],
                {'audio_signal': audio_signal.numpy(), 'length': length.numpy()},
            )
            return outputs, enc_lengths

        self.model.encode = types.MethodType(_encode_safe, self.model)

        logger.info(
            f"IndicConformer loaded: {len(self.model.vocab)} languages, "
            f"vocab sizes: {', '.join(f'{k}={len(v)}' for k,v in list(self.model.vocab.items())[:3])}..."
        )

    def unload(self):
        """Release the model and return cached CUDA memory to the driver."""
        if self.model:
            del self.model
            self.model = None
        torch.cuda.empty_cache()

    @torch.inference_mode()
    def predict_batch(
        self, waveforms: list[torch.Tensor], lang_codes: list[str],
        reference_texts: Optional[list[str]] = None,
    ) -> list[dict]:
        """Run CTC inference + scoring. Processes one at a time (ONNX doesn't batch).

        Args:
            waveforms: 1-D audio tensors, one per utterance.
            lang_codes: language code per utterance, parallel to ``waveforms``.
            reference_texts: optional Gemini transcripts, parallel to
                ``waveforms``; when given, must be at least as long.

        Returns:
            One result dict per utterance (see ``_infer_single``).
        """
        results = []
        for i, (wav, lang) in enumerate(zip(waveforms, lang_codes)):
            ref = reference_texts[i] if reference_texts else None
            results.append(self._infer_single(wav, lang, ref))
        return results

    def _infer_single(
        self, waveform: torch.Tensor, lang_code: str,
        reference_text: Optional[str] = None,
    ) -> dict:
        """Transcribe one utterance and score it.

        Args:
            waveform: 1-D audio tensor (mono samples).
            lang_code: target language; must be supported by both the config
                and the model, otherwise an empty result is returned.
            reference_text: optional Gemini transcript to score against via CER.

        Returns:
            Dict with transcription, raw greedy-path confidence, and the
            normalized (1 - CER) agreement score; unsupported languages,
            an unloaded model, or inference failures yield the empty defaults.
        """
        result = {
            "conformer_multi_transcription": "",
            "conformer_multi_ctc_raw": None,
            "conformer_multi_ctc_normalized": None,
        }

        # Guard: load() never called (or unload() already ran). Returning the
        # empty defaults here matches the unsupported-language path below
        # instead of raising AttributeError on self.model.vocab.
        if self.model is None:
            return result

        if lang_code not in CONFORMER_LANG_CODES or lang_code not in self.model.vocab:
            return result

        try:
            wav = waveform.unsqueeze(0)  # [1, samples]

            # Sync CUDA before ONNX RT — prevents context conflict with PyTorch models
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            # Encode → CTC decode (get both text and logprobs)
            encoder_outputs, encoded_lengths = self.model.encode(wav)

            raw_logprobs = self.model.models['ctc_decoder'].run(
                ['logprobs'], {'encoder_output': encoder_outputs}
            )[0]

            # Restrict the vocabulary to this language, then renormalize.
            lang_mask = self.model.language_masks[lang_code]
            masked_logprobs = torch.from_numpy(
                raw_logprobs[:, :, lang_mask]
            ).log_softmax(dim=-1)

            # Valid frame count (ONNX may return it as an array or a scalar).
            T = int(encoded_lengths[0]) if isinstance(encoded_lengths, np.ndarray) else int(encoded_lengths)
            lp = masked_logprobs[0, :T]  # [T, vocab_size]

            # Greedy CTC decode: argmax per frame, collapse repeats, drop blanks.
            blank_id = self.model.config.BLANK_ID
            vocab = self.model.vocab[lang_code]
            indices = torch.argmax(lp, dim=-1)
            collapsed = torch.unique_consecutive(indices, dim=-1)
            hyp = ''.join([vocab[x.item()] for x in collapsed if x.item() != blank_id])
            # '▁' is the SentencePiece word-boundary marker.
            transcription = hyp.replace('▁', ' ').strip()
            result["conformer_multi_transcription"] = transcription

            # Greedy path confidence: mean of max log-probs per frame.
            # Higher = model is more confident about its output.
            max_logprobs = lp.max(dim=-1).values  # [T]
            greedy_confidence = max_logprobs.mean().item()  # avg per frame

            if transcription:
                # Raw score is the greedy-path confidence itself.
                result["conformer_multi_ctc_raw"] = round(greedy_confidence, 4)
                if reference_text:
                    # Normalized score: 1 - CER vs Gemini text
                    # (higher = better agreement = higher quality).
                    cer = character_error_rate(reference_text, transcription)
                    result["conformer_multi_ctc_normalized"] = round(1.0 - cer, 4)

        except Exception as e:
            logger.warning(f"IndicConformer failed for lang={lang_code}: {e}")

        return result
