"""
MMS LID-256 (Meta): Language identification on 256 languages.
Architecture: Wav2Vec2, 1B params.  ~2GB fp16 VRAM.
Returns per-segment: top language (ISO-3), confidence, top-3 predictions.

Model loaded from R2 (no HuggingFace network calls at runtime).
"""
from __future__ import annotations

import logging
import os
import tarfile
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

from ..config import MMS_LID_MODEL, MMS_BATCH_SIZE, ISO3_TO_LANG

logger = logging.getLogger(__name__)

# Local extraction target and the R2 object key for the model tarball.
# _ensure_local() downloads models/mms-lid-256.tar and unpacks it here.
_LOCAL_DIR = Path("/tmp/mms-lid-256")
_R2_KEY = "models/mms-lid-256.tar"


class MMSLID:
    """Batch language-identification (LID) inference with MMS LID-256.

    Lifecycle: ``load()`` (downloads weights from R2 on first use, then
    caches under ``/tmp``), ``predict_batch()`` any number of times, and
    ``unload()`` to release VRAM.
    """

    def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16):
        # Device/dtype the weights are moved to in load(); fp16 keeps the
        # 1B-param model at roughly 2GB of VRAM (see module docstring).
        self.device = device
        self.dtype = dtype
        self.model: Optional[Wav2Vec2ForSequenceClassification] = None
        self.processor: Optional[AutoFeatureExtractor] = None
        # class index -> ISO-3 language code, populated in load().
        self._id2label: dict[int, str] = {}
        self.batch_size = MMS_BATCH_SIZE

    def load(self, hf_token: str = ""):
        """Load the feature extractor and model from the local R2-backed cache.

        Args:
            hf_token: Unused — weights come from R2, never the HF Hub at
                runtime; kept for signature compatibility with callers.
        """
        logger.info(f"Loading MMS LID-256 → {self.device} ({self.dtype})")
        local = self._ensure_local()
        self.processor = AutoFeatureExtractor.from_pretrained(local, local_files_only=True)
        self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
            local, torch_dtype=self.dtype, local_files_only=True,
        ).to(self.device).eval()
        self._id2label = self.model.config.id2label
        # Weight footprint only; activation memory during inference is extra.
        vram_mb = sum(p.numel() * p.element_size() for p in self.model.parameters()) / 1e6
        logger.info(f"MMS LID loaded: {len(self._id2label)} languages, ~{vram_mb:.0f}MB VRAM")

    @staticmethod
    def _ensure_local() -> str:
        """Download MMS LID from R2 if not already present locally.

        Returns:
            Path (as ``str``) of the local model directory suitable for
            ``from_pretrained(..., local_files_only=True)``.
        """
        # config.json doubles as the "extraction completed" marker.
        if (_LOCAL_DIR / "config.json").exists():
            logger.info("MMS LID already cached locally")
            return str(_LOCAL_DIR)

        # Imported lazily so the module loads on hosts without boto3 when the
        # cache is pre-populated.
        import boto3
        # Keep model artifacts decoupled from the shard output bucket so recover
        # runs can write results to a fresh bucket without breaking model loads.
        bucket = os.getenv("R2_VALIDATION_MODEL_BUCKET", "validation-results")
        tar_path = Path("/tmp/mms-lid-256.tar")

        logger.info(f"Downloading MMS LID from R2: s3://{bucket}/{_R2_KEY}")
        s3 = boto3.client("s3",
            endpoint_url=os.getenv("R2_ENDPOINT_URL"),
            aws_access_key_id=os.getenv("R2_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("R2_SECRET_ACCESS_KEY"),
            region_name="auto",
        )
        s3.download_file(bucket, _R2_KEY, str(tar_path))
        size_mb = tar_path.stat().st_size / 1e6
        logger.info(f"Downloaded {size_mb:.0f}MB, extracting...")

        try:
            # filter="data" (3.12+, backported to older security releases)
            # rejects absolute paths / traversal entries in the archive.
            with tarfile.open(tar_path, "r:*") as tf:
                tf.extractall("/tmp", filter="data")
        finally:
            # Remove the tarball even if extraction failed, so a retry
            # re-downloads from a clean slate instead of reusing a
            # possibly-corrupt archive.
            tar_path.unlink(missing_ok=True)

        logger.info(f"MMS LID extracted to {_LOCAL_DIR}")
        return str(_LOCAL_DIR)

    def unload(self):
        """Drop model/processor references and release cached CUDA memory."""
        del self.model, self.processor
        self.model = self.processor = None
        torch.cuda.empty_cache()

    @torch.inference_mode()
    def predict_batch(
        self, waveforms: list[torch.Tensor], sample_rate: int = 16000,
    ) -> list[dict]:
        """
        Run LID on a list of mono waveforms (each [samples] float32).
        Returns list of dicts with keys:
          mms_lang_iso3, mms_lang_iso1, mms_confidence, mms_top3
        """
        results: list[dict] = []
        # Chunk into model-sized mini-batches to bound peak activation memory.
        for start in range(0, len(waveforms), self.batch_size):
            chunk = waveforms[start:start + self.batch_size]
            results.extend(self._infer_batch(chunk, sample_rate))
        return results

    def _infer_batch(self, waveforms: list[torch.Tensor], sample_rate: int) -> list[dict]:
        """Score one mini-batch; returns one result dict per input waveform."""
        # Feature extractor expects numpy arrays. detach+cpu first so tensors
        # that live on GPU or carry autograd history don't make .numpy() raise.
        raw_arrays = [w.detach().cpu().numpy() for w in waveforms]
        inputs = self.processor(
            raw_arrays, sampling_rate=sample_rate,
            return_tensors="pt", padding=True,
        )
        # Cast only floating-point tensors to the model dtype (e.g. fp16);
        # integer tensors such as attention masks keep their dtype.
        inputs = {k: v.to(self.device, dtype=self.dtype if v.dtype.is_floating_point else v.dtype)
                  for k, v in inputs.items()}

        logits = self.model(**inputs).logits  # [B, num_langs]
        probs = F.softmax(logits.float(), dim=-1)  # fp32 softmax for numerical stability

        results = []
        for b in range(probs.shape[0]):
            top3_vals, top3_idx = probs[b].topk(3)

            top1_iso3 = self._id2label[top3_idx[0].item()]
            top1_conf = top3_vals[0].item()
            # Map ISO-3 -> ISO-1 where a mapping exists; otherwise keep ISO-3.
            top1_iso1 = ISO3_TO_LANG.get(top1_iso3, top1_iso3)

            top3 = [
                {"lang": self._id2label[top3_idx[j].item()],
                 "conf": round(top3_vals[j].item(), 4)}
                for j in range(3)
            ]

            results.append({
                "mms_lang_iso3": top1_iso3,
                "mms_lang_iso1": top1_iso1,
                "mms_confidence": round(top1_conf, 4),
                "mms_top3": top3,
            })
        return results
