"""
English CTC model: facebook/wav2vec2-large-960h-lv60-self
315M params, 1.9% WER on LibriSpeech, character-level vocab.
Enables proper CTC log-likelihood scoring P(text|audio) for English segments.

Replaces MMS-1B-All which was redundant with IndicConformer for Indic languages.
"""
from __future__ import annotations

import logging
import re
from typing import Optional

import torch
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from ..config import WAV2VEC_BATCH_SIZE, AUDIO_SAMPLE_RATE
from ..ctc_score import compute_ctc_score, character_error_rate

logger = logging.getLogger(__name__)

# Hugging Face Hub checkpoint id for the English character-level CTC model
# (uppercase A-Z vocab, '|' as word delimiter — see _tokenize_english).
ENGLISH_CTC_MODEL = "facebook/wav2vec2-large-960h-lv60-self"


class EnglishCTC:
    """
    wav2vec2-large English CTC model.
    Provides both decoded transcription AND proper CTC log-likelihood scoring.
    """

    def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16):
        self.device = device
        self.dtype = dtype
        self.batch_size = WAV2VEC_BATCH_SIZE
        self.model: Optional[Wav2Vec2ForCTC] = None
        self.processor: Optional[Wav2Vec2Processor] = None

    @property
    def available_languages(self) -> set[str]:
        return {"en"}

    def load(self, hf_token: str = ""):
        logger.info(f"Loading wav2vec2-large English CTC → {self.device}")
        self.processor = Wav2Vec2Processor.from_pretrained(ENGLISH_CTC_MODEL)
        self.model = Wav2Vec2ForCTC.from_pretrained(
            ENGLISH_CTC_MODEL, torch_dtype=self.dtype,
        ).to(self.device).eval()
        param_mb = sum(p.numel() * p.element_size() for p in self.model.parameters()) / 1e6
        logger.info(f"wav2vec2-large English CTC loaded: ~{param_mb:.0f}MB VRAM, vocab={self.processor.tokenizer.vocab_size}")

    def unload(self):
        del self.model, self.processor
        self.model = self.processor = None
        torch.cuda.empty_cache()

    @torch.inference_mode()
    def predict_batch(
        self, waveforms: list[torch.Tensor], lang_codes: list[str],
        reference_texts: Optional[list[str]] = None,
    ) -> list[dict]:
        """
        Run English CTC on segments. Non-English segments get empty results.
        Groups English segments for efficient batching.
        """
        en_indices = [i for i, lang in enumerate(lang_codes) if lang == "en"]
        results = [self._empty_result()] * len(waveforms)

        if not en_indices:
            return results

        for batch_start in range(0, len(en_indices), self.batch_size):
            batch_idx = en_indices[batch_start:batch_start + self.batch_size]
            batch_wavs = [waveforms[i] for i in batch_idx]
            batch_refs = [reference_texts[i] if reference_texts else None for i in batch_idx]
            batch_results = self._infer_batch(batch_wavs, batch_refs)
            for j, idx in enumerate(batch_idx):
                results[idx] = batch_results[j]

        return results

    def _infer_batch(
        self, waveforms: list[torch.Tensor],
        reference_texts: list[Optional[str]],
    ) -> list[dict]:
        raw_arrays = [w.numpy() for w in waveforms]
        inputs = self.processor(
            raw_arrays, sampling_rate=AUDIO_SAMPLE_RATE,
            return_tensors="pt", padding=True,
        )
        inputs = {k: v.to(self.device, dtype=self.dtype if v.dtype.is_floating_point else v.dtype)
                  for k, v in inputs.items()}

        logits = self.model(**inputs).logits  # [B, T, vocab=32]
        log_probs = F.log_softmax(logits.float(), dim=-1)

        pred_ids = torch.argmax(logits, dim=-1)
        decoded_texts = self.processor.batch_decode(pred_ids)

        results = []
        for b in range(len(waveforms)):
            transcription = decoded_texts[b] if b < len(decoded_texts) else ""
            ref_text = reference_texts[b]

            result = {
                "wav2vec_transcription": transcription,
                "wav2vec_ctc_raw": None,
                "wav2vec_ctc_normalized": None,
                "wav2vec_model_used": ENGLISH_CTC_MODEL,
            }

            seg_lp = log_probs[b]  # [T, 32]

            # Proper CTC log-likelihood scoring for English
            if ref_text:
                tokens = self._tokenize_english(ref_text)
                if tokens:
                    T = seg_lp.shape[0]
                    if T >= len(tokens):
                        raw, norm = compute_ctc_score(seg_lp, tokens, blank_id=0)
                        result["wav2vec_ctc_raw"] = raw
                        result["wav2vec_ctc_normalized"] = norm
                    else:
                        # Fallback to CER when T < S
                        if transcription:
                            cer = character_error_rate(ref_text, transcription)
                            greedy_conf = seg_lp.max(dim=-1).values.mean().item()
                            result["wav2vec_ctc_raw"] = round(greedy_conf, 4)
                            result["wav2vec_ctc_normalized"] = round(1.0 - cer, 4)
                elif transcription:
                    cer = character_error_rate(ref_text, transcription)
                    greedy_conf = seg_lp.max(dim=-1).values.mean().item()
                    result["wav2vec_ctc_raw"] = round(greedy_conf, 4)
                    result["wav2vec_ctc_normalized"] = round(1.0 - cer, 4)

            results.append(result)

        return results

    def _tokenize_english(self, text: str) -> list[int]:
        """
        Convert English text to token IDs for CTC scoring.
        wav2vec2-large uses uppercase character-level vocab (A-Z, space='|').
        Strips non-alpha characters, uppercases, maps to token IDs.
        """
        # Keep only letters and spaces, uppercase
        cleaned = re.sub(r"[^a-zA-Z\s]", "", text).upper().strip()
        cleaned = re.sub(r"\s+", " ", cleaned)

        if not cleaned:
            return []

        tokens = []
        for ch in cleaned:
            if ch == " ":
                tid = self.processor.tokenizer.word_delimiter_token_id
            else:
                tid = self.processor.tokenizer.convert_tokens_to_ids(ch)
            if tid is not None and tid != self.processor.tokenizer.unk_token_id:
                tokens.append(tid)

        return tokens

    @staticmethod
    def _empty_result() -> dict:
        return {
            "wav2vec_transcription": "",
            "wav2vec_ctc_raw": None,
            "wav2vec_ctc_normalized": None,
            "wav2vec_model_used": "",
        }
