"""
Per-video validation pipeline: orchestrates all models over a video's segments.
Downloads tar → loads audio → runs LID → runs CTC → returns rich segment data.
Maximizes GPU utilization by batching inference over all segments at each stage.
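
Typical usage (a sketch; segment decoding lives in audio_loader):

    pipeline = ValidationPipeline(config)
    pipeline.load_models()
    results = pipeline.process_video(video_id, segments)
    pipeline.unload_models()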
"""
from __future__ import annotations

import json
import logging
import time
from collections import Counter
from dataclasses import dataclass
from typing import Optional

import torch

from .config import ValidationConfig, CONFORMER_LANG_CODES
from .audio_loader import SegmentData
from .models.mms_lid import MMSLID
from .models.voxlingua import VoxLinguaLID
from .models.conformer_multi import IndicConformerMulti
from .models.wav2vec_lang import EnglishCTC

logger = logging.getLogger(__name__)


@dataclass
class SegmentResult:
    """Full validation result for one segment."""
    video_id: str
    segment_file: str
    duration_s: float

    # Gemini data (from transcription JSON)
    gemini_lang: str = ""
    gemini_transcription: str = ""
    gemini_tagged: str = ""
    gemini_quality_score: float = 0.0
    speaker_info: str = ""

    # MMS LID-256
    mms_lang_iso3: str = ""
    mms_lang_iso1: str = ""
    mms_confidence: float = 0.0
    mms_top3: str = ""  # JSON string

    # VoxLingua107
    vox_lang: str = ""
    vox_lang_iso1: str = ""
    vox_confidence: float = 0.0
    vox_top3: str = ""  # JSON string
    vox_speaker_embedding: bytes = b""

    # IndicConformer multilingual CTC
    conformer_multi_transcription: str = ""
    conformer_multi_ctc_raw: Optional[float] = None
    conformer_multi_ctc_normalized: Optional[float] = None

    # IndicWav2Vec per-language CTC
    wav2vec_transcription: str = ""
    wav2vec_ctc_raw: Optional[float] = None
    wav2vec_ctc_normalized: Optional[float] = None
    wav2vec_model_used: str = ""

    # Consensus
    lid_consensus: bool = False
    lid_agree_count: int = 0
    consensus_lang: str = ""
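
    # All fields are flat scalars / strings / bytes, so dataclasses.asdict()
    # can turn a result directly into one row for a dataframe or parquet writer.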


class ValidationPipeline:
    """Orchestrates model inference over all segments of a video."""

    def __init__(self, config: ValidationConfig):
        self.config = config
        self.mms: Optional[MMSLID] = None
        self.vox: Optional[VoxLinguaLID] = None
        self.conformer: Optional[IndicConformerMulti] = None
        self.wav2vec: Optional[EnglishCTC] = None
        self._loaded = False

    def load_models(self):
        """Load all enabled models to GPU."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        hf = self.config.hf_token
        t0 = time.time()

        if self.config.enable_mms_lid:
            self.mms = MMSLID(device=device)
            self.mms.load(hf)

        if self.config.enable_voxlingua:
            self.vox = VoxLinguaLID(device=device)
            self.vox.load(hf)

        if self.config.enable_conformer_multi:
            self.conformer = IndicConformerMulti(device=device)
            self.conformer.load(hf)

        if self.config.enable_wav2vec_lang:
            self.wav2vec = EnglishCTC(device=device)
            self.wav2vec.load(hf)

        self._loaded = True
        elapsed = time.time() - t0
        self._log_gpu_usage()
        logger.info(f"All models loaded in {elapsed:.1f}s")

    def unload_models(self):
        for model in [self.mms, self.vox, self.conformer, self.wav2vec]:
            if model:
                model.unload()
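        # Optional hygiene step (an addition, not in the original): after the
        # references are dropped, release cached CUDA blocks back to the driver.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()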
        self._loaded = False

    def process_video(
        self, video_id: str, segments: list[SegmentData],
    ) -> list[SegmentResult]:
        """
        Run all validation models over a video's segments.
        Returns one SegmentResult per segment with all metrics populated.
        """
        if not segments:
            return []

        t0 = time.time()
        n = len(segments)
        logger.info(f"[{video_id}] Processing {n} segments")

        waveforms = [s.waveform for s in segments]

        # === STAGE 1: LID (MMS + VoxLingua) ===
        # One independent dict per segment; [{}] * n would alias a single
        # shared dict across every slot, a footgun if entries are ever mutated
        # in place.
        mms_results: list[dict] = [{} for _ in range(n)]
        vox_results: list[dict] = [{} for _ in range(n)]

        if self.mms:
            t_lid = time.time()
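            # predict_batch returns one dict per input waveform, indexed
            # positionally; the keys consumed downstream are mms_lang_iso3 /
            # mms_lang_iso1 / mms_confidence / mms_top3 (see assembly stage).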
            mms_results = self.mms.predict_batch(waveforms)
            logger.info(f"[{video_id}] MMS LID: {n} segs in {time.time()-t_lid:.1f}s")

        if self.vox:
            t_vox = time.time()
            vox_results = self.vox.predict_batch(waveforms)
            logger.info(f"[{video_id}] VoxLingua: {n} segs in {time.time()-t_vox:.1f}s")

        # === STAGE 2: CTC scoring (Conformer + Wav2Vec) ===
        # Only for Indic segments (not English, which has no IndicConformer coverage)
        conformer_results: list[dict] = [{} for _ in range(n)]
        wav2vec_results: list[dict] = [{} for _ in range(n)]

        lang_codes = [s.gemini_lang for s in segments]
        ref_texts = [s.gemini_transcription for s in segments]
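        # The Gemini language tag doubles as the language hint for CTC model
        # selection, and the Gemini transcription is the reference text that
        # the CTC output is scored against.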

        # Filter to Indic-only for CTC
        indic_indices = [
            i for i, lang in enumerate(lang_codes) if lang in CONFORMER_LANG_CODES
        ]

        if self.conformer and indic_indices:
            t_conf = time.time()
            indic_wavs = [waveforms[i] for i in indic_indices]
            indic_langs = [lang_codes[i] for i in indic_indices]
            indic_refs = [ref_texts[i] for i in indic_indices]
            conf_batch = self.conformer.predict_batch(indic_wavs, indic_langs, indic_refs)
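            # Scatter batch outputs back to their original segment positions;
            # non-Indic slots keep their empty-dict defaults.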
            for j, idx in enumerate(indic_indices):
                conformer_results[idx] = conf_batch[j]
            logger.info(
                f"[{video_id}] Conformer CTC: {len(indic_indices)} Indic segs in "
                f"{time.time()-t_conf:.1f}s"
            )

        # Wav2Vec per-language: filter to languages that have models
        if self.wav2vec:
            wav2vec_langs = self.wav2vec.available_languages
            wav2vec_indices = [
                i for i, lang in enumerate(lang_codes) if lang in wav2vec_langs
            ]

            if wav2vec_indices:
                t_w2v = time.time()
                w2v_wavs = [waveforms[i] for i in wav2vec_indices]
                w2v_langs = [lang_codes[i] for i in wav2vec_indices]
                w2v_refs = [ref_texts[i] for i in wav2vec_indices]
                w2v_batch = self.wav2vec.predict_batch(w2v_wavs, w2v_langs, w2v_refs)
                for j, idx in enumerate(wav2vec_indices):
                    wav2vec_results[idx] = w2v_batch[j]
                logger.info(
                    f"[{video_id}] Wav2Vec CTC: {len(wav2vec_indices)} segs in "
                    f"{time.time()-t_w2v:.1f}s"
                )

        # === STAGE 3: Assemble results ===
        results: list[SegmentResult] = []

        for i, seg in enumerate(segments):
            mr = mms_results[i]
            vr = vox_results[i]
            cr = conformer_results[i]
            wr = wav2vec_results[i]

            mms_iso1 = mr.get("mms_lang_iso1", "")
            vox_iso1 = vr.get("vox_lang_iso1", "")
            gemini_lang = seg.gemini_lang

            # LID consensus: how many of the 3 sources agree?
            langs = [lang for lang in (gemini_lang, mms_iso1, vox_iso1) if lang]
            if langs:
                # Counter.most_common breaks ties by first occurrence, so when
                # all three sources disagree, consensus_lang falls back to the
                # first non-empty source and lid_consensus stays False.
                consensus_lang, agree_count = Counter(langs).most_common(1)[0]
                lid_consensus = agree_count >= 2
            else:
                consensus_lang, agree_count, lid_consensus = "", 0, False

            results.append(SegmentResult(
                video_id=video_id,
                segment_file=seg.segment_file,
                duration_s=round(seg.duration_s, 3),
                gemini_lang=gemini_lang,
                gemini_transcription=seg.gemini_transcription,
                gemini_tagged=seg.gemini_tagged,
                gemini_quality_score=seg.gemini_quality_score,
                speaker_info=json.dumps(seg.speaker_info) if seg.speaker_info else "",
                mms_lang_iso3=mr.get("mms_lang_iso3", ""),
                mms_lang_iso1=mms_iso1,
                mms_confidence=mr.get("mms_confidence", 0.0),
                mms_top3=json.dumps(mr.get("mms_top3", [])),
                vox_lang=vr.get("vox_lang", ""),
                vox_lang_iso1=vox_iso1,
                vox_confidence=vr.get("vox_confidence", 0.0),
                vox_top3=json.dumps(vr.get("vox_top3", [])),
                vox_speaker_embedding=vr.get("vox_speaker_embedding", b""),
                conformer_multi_transcription=cr.get("conformer_multi_transcription", ""),
                conformer_multi_ctc_raw=cr.get("conformer_multi_ctc_raw"),
                conformer_multi_ctc_normalized=cr.get("conformer_multi_ctc_normalized"),
                wav2vec_transcription=wr.get("wav2vec_transcription", ""),
                wav2vec_ctc_raw=wr.get("wav2vec_ctc_raw"),
                wav2vec_ctc_normalized=wr.get("wav2vec_ctc_normalized"),
                wav2vec_model_used=wr.get("wav2vec_model_used", ""),
                lid_consensus=lid_consensus,
                lid_agree_count=agree_count,
                consensus_lang=consensus_lang,
            ))

        elapsed = time.time() - t0
        segs_per_sec = n / elapsed if elapsed > 0 else 0
        logger.info(
            f"[{video_id}] Done: {n} segments in {elapsed:.1f}s "
            f"({segs_per_sec:.0f} segs/s)"
        )

        return results

    def _log_gpu_usage(self):
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            logger.info(f"GPU VRAM: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")
