"""
Meta MMS-1B-All: Multilingual CTC ASR on 1100+ languages.
Single model with per-language adapters — switch adapters instantly.
Used for CTC scoring: get log-probabilities and decode, compare with Gemini.
Replaces per-language IndicWav2Vec (which requires gated HF access).
"""
from __future__ import annotations

import logging
from typing import Optional

import torch
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, AutoProcessor

from ..config import (
    WAV2VEC_BATCH_SIZE, AUDIO_SAMPLE_RATE, LANG_TO_ISO3,
)
from ..ctc_score import character_error_rate

logger = logging.getLogger(__name__)

MMS_ASR_MODEL = "facebook/mms-1b-all"
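
# Assumption (illustrative, not taken from ..config itself): LANG_TO_ISO3 maps
# the pipeline's ISO 639-1 codes to the ISO 639-3 codes that name MMS adapters,
# e.g. {"hi": "hin", "ta": "tam"}.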


class IndicWav2VecLang:
    """
    MMS-1B-All with per-language adapters.
    Single model load, adapter swap for each language. No LRU cache needed.
    """

    def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16):
        self.device = device
        self.dtype = dtype
        self.batch_size = WAV2VEC_BATCH_SIZE
        self.model: Optional[Wav2Vec2ForCTC] = None
        self.processor: Optional[AutoProcessor] = None
        self._current_lang: str = ""

    @property
    def available_languages(self) -> set[str]:
        return set(LANG_TO_ISO3.keys())

    def load(self, hf_token: str = ""):
        logger.info(f"Loading MMS-1B-All → {self.device}")
        # MMS-1B-All itself is not gated, but forward the token so private
        # mirrors and authenticated environments still work.
        token = hf_token or None
        self.processor = AutoProcessor.from_pretrained(MMS_ASR_MODEL, token=token)
        self.model = Wav2Vec2ForCTC.from_pretrained(
            MMS_ASR_MODEL, torch_dtype=self.dtype, token=token,
        ).to(self.device).eval()
        param_mb = sum(p.numel() * p.element_size() for p in self.model.parameters()) / 1e6
        logger.info(f"MMS-1B-All loaded: ~{param_mb:.0f}MB VRAM; per-language adapters load on demand")

    def unload(self):
        # Drop references so the weights can be garbage-collected, and reset the
        # adapter state so a future load() does not skip the first adapter swap.
        self.model = self.processor = None
        self._current_lang = ""
        torch.cuda.empty_cache()

    def _switch_lang(self, lang_iso1: str):
        """Switch adapter + tokenizer to the target language (no-op if already active)."""
        if lang_iso1 == self._current_lang:
            return
        # MMS adapters are named by ISO 639-3 code; fall back to the raw code in
        # case the caller already passed ISO 639-3.
        iso3 = LANG_TO_ISO3.get(lang_iso1, lang_iso1)
        # load_adapter() pulls the per-language adapter weights (cached after the
        # first download) and swaps in the matching CTC head/vocabulary.
        self.model.load_adapter(iso3)
        self.processor.tokenizer.set_target_lang(iso3)
        self._current_lang = lang_iso1

    @torch.inference_mode()
    def predict_batch(
        self, waveforms: list[torch.Tensor], lang_codes: list[str],
        reference_texts: Optional[list[str]] = None,
    ) -> list[dict]:
        """
        Run CTC inference + scoring. Groups by language for adapter efficiency.
        """
        # Group by language to minimize adapter swaps
        lang_groups: dict[str, list[tuple[int, torch.Tensor, Optional[str]]]] = {}
        for idx, (wav, lang) in enumerate(zip(waveforms, lang_codes)):
            ref = reference_texts[idx] if reference_texts else None
            lang_groups.setdefault(lang, []).append((idx, wav, ref))

        results: list[Optional[dict]] = [None] * len(waveforms)

        for lang, items in lang_groups.items():
            if lang not in LANG_TO_ISO3:
                logger.warning(f"No MMS adapter mapping for language '{lang}'; returning empty results")
                for idx, _, _ in items:
                    results[idx] = self._empty_result()
                continue

            self._switch_lang(lang)

            for batch_start in range(0, len(items), self.batch_size):
                batch_items = items[batch_start:batch_start + self.batch_size]
                batch_results = self._infer_batch(batch_items)
                for (idx, _, _), res in zip(batch_items, batch_results):
                    results[idx] = res

        for i in range(len(results)):
            if results[i] is None:
                results[i] = self._empty_result()

        return results

    def _infer_batch(
        self, items: list[tuple[int, torch.Tensor, Optional[str]]],
    ) -> list[dict]:
        # The processor expects raw float arrays; move to CPU first in case the
        # caller passes GPU tensors.
        raw_arrays = [item[1].cpu().numpy() for item in items]

        inputs = self.processor(
            raw_arrays, sampling_rate=AUDIO_SAMPLE_RATE,
            return_tensors="pt", padding=True,
        )
        # Cast float tensors to the model dtype; keep integer tensors
        # (e.g. attention_mask) as-is.
        inputs = {k: v.to(self.device, dtype=self.dtype if v.dtype.is_floating_point else v.dtype)
                  for k, v in inputs.items()}

        logits = self.model(**inputs).logits  # [B, T, vocab]
        log_probs = F.log_softmax(logits.float(), dim=-1)

        # Map input-sample lengths to output-frame lengths so padded frames can be
        # excluded from the confidence average below.
        attn = inputs.get("attention_mask")
        if attn is not None:
            out_lens = self.model._get_feat_extract_output_lengths(attn.sum(-1)).tolist()
        else:
            out_lens = [log_probs.shape[1]] * len(items)

        pred_ids = torch.argmax(logits, dim=-1)
        decoded_texts = self.processor.batch_decode(pred_ids)

        results = []
        for b, (idx, wav, ref_text) in enumerate(items):
            transcription = decoded_texts[b] if b < len(decoded_texts) else ""

            result = {
                "wav2vec_transcription": transcription,
                "wav2vec_ctc_raw": None,
                "wav2vec_ctc_normalized": None,
                "wav2vec_model_used": MMS_ASR_MODEL,
            }

            # Greedy confidence + CER scoring (same approach as the conformer
            # scorer): mean over valid frames of the per-frame max log-probability.
            seg_lp = log_probs[b, : out_lens[b]]  # [T_valid, vocab]
            greedy_conf = seg_lp.max(dim=-1).values.mean().item()

            if ref_text and transcription:
                cer = character_error_rate(ref_text, transcription)
                result["wav2vec_ctc_raw"] = round(greedy_conf, 4)
                result["wav2vec_ctc_normalized"] = round(1.0 - cer, 4)
            elif transcription:
                result["wav2vec_ctc_raw"] = round(greedy_conf, 4)

            results.append(result)

        return results

    @staticmethod
    def _empty_result() -> dict:
        return {
            "wav2vec_transcription": "",
            "wav2vec_ctc_raw": None,
            "wav2vec_ctc_normalized": None,
            "wav2vec_model_used": "",
        }
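

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline. Assumes the relative
    # config imports resolve when run as a module; downloads several GB of
    # weights on first run. One second of silence should decode to "".
    logging.basicConfig(level=logging.INFO)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    asr = IndicWav2VecLang(device=dev, dtype=torch.float16 if dev == "cuda" else torch.float32)
    asr.load()
    silence = torch.zeros(AUDIO_SAMPLE_RATE)  # 1 s at the configured sample rate
    lang = next(iter(LANG_TO_ISO3))  # any configured language
    print(asr.predict_batch([silence], [lang]))
    asr.unload()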
