"""
MMS LID-256 (Meta): spoken-language identification over 256 languages.
Architecture: Wav2Vec2, ~1B params, ~2 GB VRAM in fp16.
Returns, per segment: the top language (ISO-3 code, plus its mapped short code
from ISO3_TO_LANG when available), its confidence, and the top-3 predictions.
"""
from __future__ import annotations

import logging
from typing import Optional

import torch
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

from ..config import MMS_LID_MODEL, MMS_BATCH_SIZE, ISO3_TO_LANG

logger = logging.getLogger(__name__)


class MMSLID:
    """Batch LID inference with MMS LID-256.

    Typical use: ``lid = MMSLID(); lid.load(); lid.predict_batch(waves)``.
    The model and feature extractor are only populated by :meth:`load` and
    dropped again by :meth:`unload`.
    """

    def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16):
        self.device = device
        self.dtype = dtype
        self.model: Optional[Wav2Vec2ForSequenceClassification] = None
        self.processor: Optional[AutoFeatureExtractor] = None
        # Class index -> ISO-3 label, copied from the model config on load().
        self._id2label: dict[int, str] = {}
        self.batch_size = MMS_BATCH_SIZE

    def load(self, hf_token: str = ""):
        """Load the MMS LID feature extractor and model onto ``self.device``.

        Args:
            hf_token: Optional Hugging Face access token forwarded to
                ``from_pretrained``. An empty string means no token is passed.
        """
        logger.info("Loading MMS LID-256 → %s (%s)", self.device, self.dtype)
        # Forward the token only when one was given, so the default call is
        # exactly an anonymous from_pretrained().
        hub_kwargs = {"token": hf_token} if hf_token else {}
        self.processor = AutoFeatureExtractor.from_pretrained(MMS_LID_MODEL, **hub_kwargs)
        self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
            MMS_LID_MODEL, torch_dtype=self.dtype, **hub_kwargs,
        ).to(self.device).eval()
        self._id2label = self.model.config.id2label
        # Counts parameter bytes only (no buffers or activations), so this is a rough lower bound.
        vram_mb = sum(p.numel() * p.element_size() for p in self.model.parameters()) / 1e6
        logger.info(
            "MMS LID loaded: %d languages, ~%.0fMB VRAM", len(self._id2label), vram_mb,
        )

    def unload(self):
        """Drop the model/processor references and free cached CUDA memory."""
        del self.model, self.processor
        self.model = self.processor = None
        self._id2label = {}
        torch.cuda.empty_cache()

    @torch.inference_mode()
    def predict_batch(
        self, waveforms: list[torch.Tensor], sample_rate: int = 16000,
    ) -> list[dict]:
        """
        Run LID on a list of mono waveforms (each [samples] float32).

        Inputs are processed in chunks of ``self.batch_size``. Returns one dict
        per waveform, in input order, with keys:
          mms_lang_iso3, mms_lang_iso1, mms_confidence, mms_top3
        ``mms_top3`` holds up to 3 ``{"lang", "conf"}`` entries, best first.
        An empty input list returns ``[]`` without touching the model.

        Raises:
            RuntimeError: if called with waveforms before :meth:`load`.
        """
        results = []
        for i in range(0, len(waveforms), self.batch_size):
            batch_wavs = waveforms[i:i + self.batch_size]
            results.extend(self._infer_batch(batch_wavs, sample_rate))
        return results

    def _infer_batch(self, waveforms: list[torch.Tensor], sample_rate: int) -> list[dict]:
        """Run one forward pass over ``waveforms`` and format the predictions."""
        if self.model is None or self.processor is None:
            raise RuntimeError("MMSLID model is not loaded; call load() first")
        # The feature extractor consumes numpy arrays. Only CPU tensors in a
        # numpy-supported dtype can be converted, so normalise to CPU float32 here.
        raw_arrays = [w.detach().cpu().float().numpy() for w in waveforms]
        # padding=True pads the batch to its longest clip.
        inputs = self.processor(
            raw_arrays, sampling_rate=sample_rate,
            return_tensors="pt", padding=True,
        )
        # Cast floating-point inputs (the audio) to the model dtype. Integer
        # tensors (e.g. an attention mask, if the processor emits one) keep theirs.
        inputs = {
            k: v.to(self.device, dtype=self.dtype if v.dtype.is_floating_point else v.dtype)
            for k, v in inputs.items()
        }

        logits = self.model(**inputs).logits  # [B, num_langs]
        probs = F.softmax(logits.float(), dim=-1)  # softmax in fp32, not the model dtype
        return self._format_predictions(probs, self._id2label, ISO3_TO_LANG)

    @staticmethod
    def _format_predictions(
        probs: torch.Tensor,
        id2label: dict[int, str],
        iso3_to_lang: dict[str, str],
        k: int = 3,
    ) -> list[dict]:
        """Turn a [B, num_labels] probability tensor into per-row result dicts.

        ``k`` is clamped to the number of labels so small label sets do not
        crash ``topk``. Labels missing from ``iso3_to_lang`` fall back to the
        ISO-3 code itself.
        """
        k = min(k, probs.shape[-1])
        top_vals, top_idx = probs.topk(k, dim=-1)
        # A single device->host transfer for the whole batch, instead of one
        # .item() sync per value.
        rows = zip(top_vals.cpu().tolist(), top_idx.cpu().tolist())

        results = []
        for vals, idxs in rows:
            labels = [id2label[i] for i in idxs]
            top1_iso3 = labels[0]
            results.append({
                "mms_lang_iso3": top1_iso3,
                "mms_lang_iso1": iso3_to_lang.get(top1_iso3, top1_iso3),
                "mms_confidence": round(vals[0], 4),
                "mms_top3": [
                    {"lang": lang, "conf": round(conf, 4)}
                    for lang, conf in zip(labels, vals)
                ],
            })
        return results
