"""
VoxLingua107 ECAPA-TDNN (SpeechBrain): spoken language identification (LID) over
107 languages, plus 256-dim speaker embeddings.
Architecture: ECAPA-TDNN, ~14M params.  ~200MB VRAM.
Returns, per segment: the top language, its confidence, the top-3 candidates, and
the speaker embedding.

Model loaded from R2 (no HuggingFace network calls at runtime).
"""
from __future__ import annotations

import logging
import os
import tarfile
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F

from ..config import VOXLINGUA_MODEL, VOXLINGUA_BATCH_SIZE, VOXLINGUA_LABEL_MAP
from ..audio_loader import collate_waveforms

logger = logging.getLogger(__name__)

# Local directory the model archive is extracted into; if it already holds the
# checkpoint, later loads skip the R2 download.
_LOCAL_DIR = Path("/tmp/voxlingua107")
# Object key of the model tarball inside the R2 bucket.
_R2_KEY = "models/voxlingua107.tar"


class VoxLinguaLID:
    """Batch language identification + speaker-embedding extraction with VoxLingua107.

    Usage: construct, call ``load()`` once, then ``predict_batch()`` on lists of
    mono waveforms; ``unload()`` drops the model again.
    """

    def __init__(self, device: str = "cuda"):
        """
        Args:
            device: torch device string the model is placed on (e.g. "cuda", "cpu").
        """
        self.device = device
        self.model = None  # SpeechBrain EncoderClassifier once load() has run
        self.batch_size = VOXLINGUA_BATCH_SIZE

    def load(self, hf_token: str = "") -> None:
        """Load the SpeechBrain classifier from the local (R2-sourced) model copy.

        Args:
            hf_token: Unused; kept for interface compatibility.
        """
        logger.info(f"Loading VoxLingua107 → {self.device}")
        save_dir = self._ensure_local()

        # SpeechBrain compat: strip the deprecated ``use_auth_token`` kwarg from
        # hf_hub_download calls. The patch is installed before SpeechBrain is
        # imported and is always restored, even when loading raises.
        import huggingface_hub

        original_download = huggingface_hub.hf_hub_download

        def _download_without_auth_token(*args, **kwargs):
            kwargs.pop("use_auth_token", None)
            return original_download(*args, **kwargs)

        huggingface_hub.hf_hub_download = _download_without_auth_token
        try:
            from speechbrain.inference.classifiers import EncoderClassifier

            self.model = EncoderClassifier.from_hparams(
                source=str(save_dir),
                savedir=str(save_dir),
                run_opts={"device": self.device},
            )
        finally:
            huggingface_hub.hf_hub_download = original_download
        logger.info("VoxLingua107 loaded (~14M params)")

    @staticmethod
    def _ensure_local() -> Path:
        """Make sure the VoxLingua107 files exist under ``_LOCAL_DIR``.

        If ``embedding_model.ckpt`` is already present nothing is downloaded.
        Otherwise the tarball ``_R2_KEY`` is fetched from the R2 bucket named by
        ``R2_VALIDATION_MODEL_BUCKET`` (default "validation-results") and extracted
        into /tmp. ``hyperparams.yaml`` is then rewritten so ``pretrained_path``
        points at the local directory.

        Returns:
            The local model directory.

        Raises:
            FileNotFoundError: if the extracted archive did not produce
                ``embedding_model.ckpt`` under ``_LOCAL_DIR``.
        """
        checkpoint = _LOCAL_DIR / "embedding_model.ckpt"
        if checkpoint.exists():
            logger.info("VoxLingua107 already cached locally")
            return _LOCAL_DIR

        import boto3

        bucket = os.getenv("R2_VALIDATION_MODEL_BUCKET", "validation-results")
        tar_path = Path("/tmp/voxlingua107.tar")

        logger.info(f"Downloading VoxLingua107 from R2: s3://{bucket}/{_R2_KEY}")
        s3 = boto3.client(
            "s3",
            endpoint_url=os.getenv("R2_ENDPOINT_URL"),
            aws_access_key_id=os.getenv("R2_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("R2_SECRET_ACCESS_KEY"),
            region_name="auto",
        )
        try:
            s3.download_file(bucket, _R2_KEY, str(tar_path))
            size_mb = tar_path.stat().st_size / 1e6
            logger.info(f"Downloaded {size_mb:.0f}MB, extracting...")
            # The archive is expected to hold a top-level "voxlingua107/" directory,
            # so extracting into /tmp yields _LOCAL_DIR.
            with tarfile.open(tar_path, "r:*") as tf:
                tf.extractall("/tmp", filter="data")
        finally:
            # Never leave a (possibly partial) tarball behind in /tmp.
            tar_path.unlink(missing_ok=True)

        if not checkpoint.exists():
            raise FileNotFoundError(
                f"VoxLingua107 archive s3://{bucket}/{_R2_KEY} did not produce {checkpoint}"
            )

        # Patch hyperparams.yaml to use local paths instead of HF repo references
        hp = _LOCAL_DIR / "hyperparams.yaml"
        if hp.exists():
            text = hp.read_text()
            text = text.replace(
                "pretrained_path: speechbrain/lang-id-voxlingua107-ecapa",
                f"pretrained_path: {_LOCAL_DIR}",
            )
            hp.write_text(text)

        logger.info(f"VoxLingua107 extracted to {_LOCAL_DIR}")
        return _LOCAL_DIR

    def unload(self) -> None:
        """Drop the model reference and release PyTorch's cached CUDA memory."""
        self.model = None
        torch.cuda.empty_cache()

    @torch.inference_mode()
    def predict_batch(
        self, waveforms: list[torch.Tensor], sample_rate: int = 16000,
    ) -> list[dict]:
        """
        Run LID + embedding extraction on a list of mono waveforms.

        Waveforms are processed in chunks of ``self.batch_size``. ``sample_rate`` is
        not used for resampling; waveforms are fed to the model as given.

        Returns list of dicts (one per waveform, in input order) with keys:
          vox_lang, vox_lang_iso1, vox_confidence, vox_top3, vox_speaker_embedding
        """
        if sample_rate != 16000:
            logger.warning(
                f"VoxLingua107 expects 16000 Hz audio; got {sample_rate} Hz "
                "and no resampling is performed"
            )
        results: list[dict] = []
        for start in range(0, len(waveforms), self.batch_size):
            chunk = waveforms[start:start + self.batch_size]
            results.extend(self._infer_batch(chunk, sample_rate))
        return results

    def _infer_batch(self, waveforms: list[torch.Tensor], sample_rate: int) -> list[dict]:
        """Run one padded batch through the model; see ``predict_batch`` for output keys.

        Raises:
            RuntimeError: if ``load()`` has not been called.
        """
        if self.model is None:
            raise RuntimeError("VoxLinguaLID.load() must be called before inference")

        padded, lengths = collate_waveforms(waveforms)
        padded = padded.to(self.device)
        # SpeechBrain expects lengths relative to the longest item in the batch.
        wav_lens = (lengths.float() / lengths.max().float()).to(self.device)

        # Same computation as EncoderClassifier.classify_batch (encode_batch followed
        # by the classifier head), but the embedding is computed once and reused
        # instead of running the encoder a second time.
        embeddings = self.model.encode_batch(padded, wav_lens)  # assumed [B, 1, 256]
        out_prob = self.model.mods.classifier(embeddings).squeeze(1)  # [B, num_langs]

        # float32 so the serialized embedding bytes have a fixed dtype/size.
        embeddings = embeddings.squeeze(1).float().cpu()
        label_encoder = self.model.hparams.label_encoder
        return self._build_results(
            out_prob, embeddings, label_encoder.decode_ndim, self._label_to_iso1,
        )

    @staticmethod
    def _build_results(out_prob: torch.Tensor, embeddings: torch.Tensor,
                       decode_label, to_iso1) -> list[dict]:
        """Convert classifier scores and embeddings into per-item result dicts.

        Args:
            out_prob: [B, num_langs] classifier scores, either log-probabilities
                (every value <= 0) or linear probabilities.
            embeddings: [B, D] speaker embeddings on CPU.
            decode_label: maps an int class index to its label (e.g. 'hi: Hindi');
                a returned single-element list is unwrapped.
            to_iso1: maps a label to its ISO-639-1 code.

        Returns:
            One dict per row; ``vox_confidence`` always equals the top-1 entry of
            ``vox_top3`` (both linear probabilities rounded to 4 decimals).
        """
        probs = out_prob.detach().float().cpu()
        # Log-probabilities are never positive, while linear probabilities contain
        # positive values. Decide once per batch so top-k and confidence share one
        # linear scale. A confident log-prob of exactly 0.0 thus maps to 1.0.
        if bool((probs <= 0).all()):
            probs = probs.exp()

        def _decode(idx: int):
            label = decode_label(idx)
            return label[0] if isinstance(label, list) else label

        results = []
        for row_probs, emb in zip(probs, embeddings):
            top_vals, top_idx = row_probs.topk(min(3, row_probs.shape[0]))
            top_vals, top_idx = top_vals.tolist(), top_idx.tolist()
            top3 = [
                {"lang": str(_decode(idx)), "conf": round(val, 4)}
                for val, idx in zip(top_vals, top_idx)
            ]
            best_label = _decode(top_idx[0])
            results.append({
                "vox_lang": best_label,
                "vox_lang_iso1": to_iso1(best_label),
                "vox_confidence": round(top_vals[0], 4),
                "vox_top3": top3,
                "vox_speaker_embedding": emb.numpy().tobytes(),
            })
        return results

    @staticmethod
    def _label_to_iso1(label: str) -> str:
        """Convert a VoxLingua label to ISO-639-1, e.g. 'hi: Hindi' → 'hi'.

        The full label is looked up in VOXLINGUA_LABEL_MAP first, then its code
        prefix (text before ':'). If neither is mapped, the stripped prefix is
        returned unchanged.
        """
        if label in VOXLINGUA_LABEL_MAP:
            return VOXLINGUA_LABEL_MAP[label]
        code = label.split(":")[0].strip()
        return VOXLINGUA_LABEL_MAP.get(code, code)
