"""Hot GPU encoder: keeps XCodec2 warm and processes batched segments.

XCodec2-only mode for SFT data encoding. BiCodec removed entirely.
The encoder stays loaded for the entire worker lifetime — segments flow in,
tokens flow out.

Chunking rule (zero data loss) — IDENTICAL to pretraining pipeline:
  - Segments <= 6s: single chunk, pad to hop boundary.
  - Segments > 6s: overlapping 6s windows (stride 5.8s, overlap 0.2s).
    If the would-be last chunk is shorter than MIN_TAIL, it is absorbed
    into the previous chunk (extending it to 6-8.8s). The model handles
    variable lengths natively (Conv1d + relative position embeddings).
  - After encoding, tokens are stitched with a deterministic center-cut
    rule: drop HALF_OVERLAP tokens at each internal boundary.

Ref: XCodec2 training: min_audio_length=96000 (6s random crops).
     Inference: full-length audio, pad to multiple of 320.
     CodecEncoder hop = prod([2,2,4,4,5]) = 320, so 50 tokens/sec.
"""

from __future__ import annotations

import gc
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path

import torch
import torch.nn.functional as F

from codecbench.codecs.base import TokenBatch
from codecbench.pipeline.config import CodecConfig
from codecbench.pipeline.vad import Segment

logger = logging.getLogger(__name__)

# ── Chunking / stitching constants ────────────────────────────────────
# Encoder hop: one token is produced per CODEC_HOP input samples.
CODEC_HOP = 320
# Token rate at the 16 kHz working sample rate.
TOKENS_PER_SEC = 50   # 16000 / 320
# Audio shared by two consecutive windows of a multi-chunk segment (seconds).
OVERLAP_S = 0.2
# The same overlap expressed in tokens.
OVERLAP_TOKENS = int(OVERLAP_S * TOKENS_PER_SEC)  # 10
# Tokens trimmed on each side of an internal boundary when stitching.
HALF_OVERLAP = OVERLAP_TOKENS // 2                  # 5
# A trailing chunk shorter than this is merged into the previous chunk.
MIN_TAIL_S = 3.0
# Sample count of a standard window; chunks whose padded length matches the
# padded form of this size take the batched encode path.
STANDARD_CHUNK = 96_000  # 6s at 16kHz


@dataclass
class EncodedSegment:
    """Result of encoding one audio segment through XCodec2.

    For multi-chunk segments (>6s), tokens are stitched from overlapping
    windows with center-cut boundary trimming. The stored tokens are
    one contiguous sequence covering the full segment.
    """
    # Index assigned by the producer: stitched results carry the parent
    # segment's position in the input list; per-chunk results carry a
    # chunk-local index (always 0 from the single-chunk encode path).
    segment_idx: int
    # Start / end of the covered audio in seconds, copied from the Segment.
    start_s: float
    end_s: float
    # Duration in seconds, copied from the Segment's duration_s.
    duration_s: float
    xcodec2_tokens: torch.Tensor   # [1, T_tok]
    # Wall-clock encode time in milliseconds (summed over chunks when stitched).
    encode_time_ms: float = 0.0
    # Number of encoded chunks that were combined into this result.
    num_chunks: int = 1


class HotEncoder:
    """Keeps XCodec2 loaded and warm on GPU.

    Call load() once at startup, then encode_segments() for each batch.
    Model is never unloaded between processing units.
    """

    def __init__(self, cfg: CodecConfig, device: str = "cuda"):
        self._cfg = cfg
        self._device = device
        self._xcodec = None
        self._loaded = False
        self._total_encode_time = 0.0
        self._total_segments = 0

    def load(self) -> None:
        """Load XCodec2 and warm it up."""
        from codecbench.codecs.xcodec2_fast import FastXCodec2Codec

        logger.info("Loading XCodec2...")
        self._xcodec = FastXCodec2Codec(model_id=self._cfg.xcodec2_model_id)
        if self._cfg.xcodec2_custom_ckpt:
            self._load_custom_xcodec_ckpt(self._cfg.xcodec2_custom_ckpt)
        else:
            self._xcodec.load(device=self._device)
        self._xcodec._use_compile = True
        self._xcodec.warmup(batch_size=self._cfg.xcodec_batch_size)

        vram = torch.cuda.memory_allocated() / 1e6
        logger.info("XCodec2 loaded and warm. VRAM: %.0f MB", vram)
        self._loaded = True

    def _load_custom_xcodec_ckpt(self, ckpt_path: str) -> None:
        """Load a custom XCodec2 checkpoint (fine-tuned weights)."""
        from xcodec2.modeling_xcodec2 import XCodec2Model
        from codecbench.codecs.xcodec2_fast import (
            GPUMelExtractor, _apply_layer_truncation, _apply_sdpa_patch,
        )

        logger.info("Loading custom XCodec2 checkpoint: %s", ckpt_path)

        model = XCodec2Model.from_pretrained(self._cfg.xcodec2_model_id)
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        state_dict = ckpt.get("state_dict", ckpt.get("model", ckpt))
        cleaned = {}
        for k, v in state_dict.items():
            k = k.replace("model.", "", 1) if k.startswith("model.") else k
            cleaned[k] = v
        missing, unexpected = model.load_state_dict(cleaned, strict=False)
        if missing:
            logger.warning("Custom ckpt missing %d keys: %s", len(missing), missing[:5])
        if unexpected:
            logger.warning("Custom ckpt has %d unexpected keys: %s", len(unexpected), unexpected[:5])

        model.eval().to(self._device)
        _apply_layer_truncation(model.semantic_model)
        self._xcodec._mel_extractor = GPUMelExtractor(model.feature_extractor, device=self._device)
        self._xcodec._mel_extractor.to(self._device)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        _apply_sdpa_patch(model.semantic_model)
        self._xcodec._model = model
        self._xcodec._device = self._device
        logger.info("Custom XCodec2 checkpoint loaded with all optimizations")

    @property
    def loaded(self) -> bool:
        """Whether load() has run to completion on this instance."""
        return self._loaded

    @property
    def avg_encode_ms(self) -> float:
        if self._total_segments == 0:
            return 0.0
        return self._total_encode_time / self._total_segments

    # ── Overlap chunking (absorb-remainder, zero data loss) ───────────

    @staticmethod
    def _pad_to_hop(wav: torch.Tensor) -> torch.Tensor:
        """Right-pad the last axis with zeros up to a CODEC_HOP multiple.

        The pad amount is always in 1..CODEC_HOP: an input whose length is
        already an exact multiple (e.g. a standard 96000-sample chunk) still
        gains one full hop, mirroring the unconditional +320 that
        encode_code applies.
        """
        length = wav.shape[-1]
        leftover = length % CODEC_HOP
        return F.pad(wav, (0, CODEC_HOP - leftover))

    def _split_to_overlap_chunks(
        self, segments: list[Segment],
    ) -> tuple[list[Segment], list[int], list[tuple[int, int, Segment]]]:
        """Split segments into overlapping chunks. Never discards audio.

        Rule:
          1. T <= chunk_samples (6s): single chunk, pad to hop boundary.
          2. T > chunk_samples: stride-based 6s windows. If the would-be
             last chunk has < MIN_TAIL_S of audio, drop it and extend the
             previous chunk to the segment end (making it 6-8.8s).

        Returns:
            flat_chunks:   padded chunks ready for the GPU
            valid_samples: actual audio samples per chunk (before padding)
            groups:        (start_idx, count, original_segment) per parent
        """
        chunk_samples = int(self._cfg.chunk_seconds * self._cfg.target_sr)
        overlap_samples = int(OVERLAP_S * self._cfg.target_sr)
        stride = chunk_samples - overlap_samples
        min_tail = int(MIN_TAIL_S * self._cfg.target_sr)
        sr = self._cfg.target_sr

        flat_chunks: list[Segment] = []
        valid_samples: list[int] = []
        groups: list[tuple[int, int, Segment]] = []

        for seg in segments:
            T = seg.audio.shape[-1]
            group_start = len(flat_chunks)

            # ── Single chunk (<=6s): pad to hop boundary ──
            if T <= chunk_samples:
                flat_chunks.append(Segment(
                    start_s=seg.start_s, end_s=seg.end_s,
                    audio=self._pad_to_hop(seg.audio),
                ))
                valid_samples.append(T)
                groups.append((group_start, 1, seg))
                continue

            # ── Multi-chunk: generate stride positions ──
            starts: list[int] = []
            pos = 0
            while pos + chunk_samples <= T:
                starts.append(pos)
                pos += stride

            # Remainder: audio from last chunk_end to T
            last_end = starts[-1] + chunk_samples if starts else 0
            remainder = T - last_end

            if remainder > 0:
                if remainder >= min_tail:
                    # Remainder is long enough to be its own chunk
                    starts.append(T - remainder)
                else:
                    # Absorb remainder: extend the last chunk to segment end.
                    # Remove the last start and replace with one that reaches T.
                    # The new last chunk runs from (prev_last_start) to T.
                    pass  # starts stays as-is, we'll extend the last chunk below

            # ── Build chunks from starts ──
            for ci, s in enumerate(starts):
                is_last = ci == len(starts) - 1

                if is_last:
                    # Last chunk extends to segment end (absorbs remainder)
                    end = T
                else:
                    end = s + chunk_samples

                actual = end - s
                chunk_wav = seg.audio[..., s:end]
                padded = self._pad_to_hop(chunk_wav)

                flat_chunks.append(Segment(
                    start_s=seg.start_s + s / sr,
                    end_s=seg.start_s + end / sr,
                    audio=padded,
                ))
                valid_samples.append(actual)

            groups.append((group_start, len(starts), seg))

        n_extended = sum(1 for vs in valid_samples if vs > chunk_samples)
        if len(flat_chunks) != len(segments):
            logger.debug(
                "Overlap split: %d segments -> %d chunks (%d standard, %d extended)",
                len(segments), len(flat_chunks),
                len(flat_chunks) - n_extended, n_extended,
            )
        return flat_chunks, valid_samples, groups

    # ── Token trimming + center-cut stitching ─────────────────────────

    def _stitch_group(
        self,
        encoded: list[EncodedSegment],
        valids: list[int],
        original_seg: Segment,
        seg_idx: int,
    ) -> EncodedSegment:
        """Center-cut stitch overlapping chunks into one contiguous segment.

        Rule (deterministic, identical to pretraining pipeline):
          First chunk:  keep tokens[ 0 : valid_tok - H ]
          Middle chunk: keep tokens[ H : valid_tok - H ]
          Last chunk:   keep tokens[ H : valid_tok     ]
          Single chunk: keep tokens[ 0 : valid_tok     ]

        Where H = HALF_OVERLAP = 5 tokens (0.1s).
        """
        H = HALF_OVERLAP
        total_ms = sum(e.encode_time_ms for e in encoded)

        if len(encoded) == 1:
            ec = encoded[0]
            vt = valids[0] // CODEC_HOP
            return EncodedSegment(
                segment_idx=seg_idx,
                start_s=original_seg.start_s,
                end_s=original_seg.end_s,
                duration_s=original_seg.duration_s,
                xcodec2_tokens=ec.xcodec2_tokens[:, :vt],
                encode_time_ms=total_ms,
                num_chunks=1,
            )

        x_parts: list[torch.Tensor] = []
        n = len(encoded)

        for i, (ec, vs) in enumerate(zip(encoded, valids)):
            vt = vs // CODEC_HOP
            xtok = ec.xcodec2_tokens[:, :vt]

            if i == 0:
                x_parts.append(xtok[:, :max(vt - H, 1)])
            elif i == n - 1:
                x_parts.append(xtok[:, min(H, vt):])
            else:
                x_parts.append(xtok[:, min(H, vt):max(vt - H, H + 1)])

        return EncodedSegment(
            segment_idx=seg_idx,
            start_s=original_seg.start_s,
            end_s=original_seg.end_s,
            duration_s=original_seg.duration_s,
            xcodec2_tokens=torch.cat(x_parts, dim=1),
            encode_time_ms=total_ms,
            num_chunks=n,
        )

    # ── Flat-chunk encoding ───────────────────────────────────────────

    def _encode_one(self, seg: Segment) -> EncodedSegment | None:
        """Encode a single chunk (any length) through XCodec2."""
        t0 = time.perf_counter()
        wav = seg.audio.to(self._device)
        if wav.ndim == 2:
            wav = wav.unsqueeze(0)  # [1, 1, T]

        try:
            x_tb = self._xcodec.encode(wav, self._cfg.target_sr)
            torch.cuda.synchronize()
        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            logger.warning("CUDA error in _encode_one (seg %.1f-%.1fs): %s", seg.start_s, seg.end_s, str(e)[:120])
            torch.cuda.empty_cache()
            gc.collect()
            return None

        if x_tb is None:
            logger.warning("XCodec2 returned None for segment %.1f-%.1fs", seg.start_s, seg.end_s)
            return None

        elapsed_ms = (time.perf_counter() - t0) * 1000
        self._total_encode_time += elapsed_ms
        self._total_segments += 1

        try:
            return EncodedSegment(
                segment_idx=0,
                start_s=seg.start_s, end_s=seg.end_s, duration_s=seg.duration_s,
                xcodec2_tokens=x_tb.tokens[0:1].cpu(),
                encode_time_ms=elapsed_ms,
            )
        except (AttributeError, IndexError, KeyError) as e:
            logger.warning("Failed to extract tokens from codec output: %s", e)
            return None

    def _encode_batch(
        self, chunks: list[Segment], xcodec_batch_size_override: int | None = None,
    ) -> list[EncodedSegment]:
        """Encode a batch of same-length chunks (standard 6s path). XCodec2 only."""
        results = []
        xbs = max(xcodec_batch_size_override or self._cfg.xcodec_batch_size, 1)

        for i in range(0, len(chunks), xbs):
            batch = chunks[i : i + xbs]
            t0 = time.perf_counter()

            x_wavs = []
            for seg in batch:
                wav = seg.audio.to(self._device)
                if wav.ndim == 2:
                    wav = wav.unsqueeze(0)
                x_wavs.append(wav)
            x_input = torch.cat(x_wavs, dim=0)

            x_tb = self._xcodec.encode(x_input, self._cfg.target_sr)
            torch.cuda.synchronize()

            elapsed_ms = (time.perf_counter() - t0) * 1000
            self._total_encode_time += elapsed_ms
            self._total_segments += len(batch)

            if x_tb is None:
                logger.warning("Batch encode returned None (batch_start=%d)", i)
                for j in range(len(batch)):
                    results.append(None)
                continue

            try:
                if x_tb.tokens.shape[0] < len(batch):
                    raise IndexError(
                        f"XCodec batch output too small: {x_tb.tokens.shape[0]} < {len(batch)}"
                    )

                for j, seg in enumerate(batch):
                    results.append(EncodedSegment(
                        segment_idx=i + j,
                        start_s=seg.start_s, end_s=seg.end_s,
                        duration_s=seg.duration_s,
                        xcodec2_tokens=x_tb.tokens[j:j + 1].cpu(),
                        encode_time_ms=elapsed_ms / len(batch),
                    ))
            except (AttributeError, IndexError, KeyError) as e:
                logger.warning("Failed extracting tokens from batch result: %s", e)
                for j in range(len(batch)):
                    results.append(None)

        return results

    # ── Public API ────────────────────────────────────────────────────

    def encode_segments(
        self,
        segments: list[Segment],
        xcodec_batch_size_override: int | None = None,
    ) -> list[EncodedSegment]:
        """Encode audio segments with XCodec2. Zero data loss.

        Flow (identical to pretraining pipeline):
          1. Split >6s segments into overlapping chunks (stride 5.8s).
             Short remainders are absorbed into the previous chunk.
          2. Encode standard chunks batched; extended chunks B=1.
          3. Trim tokens to actual audio length (no padding tokens stored).
          4. Stitch overlapping chunks per parent (center-cut rule).

        Returns one EncodedSegment per original segment.
        """
        if not segments:
            return []

        flat_chunks, valid_samples, groups = self._split_to_overlap_chunks(segments)

        std_indices: list[int] = []
        ext_indices: list[int] = []
        for ci, vs in enumerate(valid_samples):
            if flat_chunks[ci].audio.shape[-1] == self._pad_to_hop(
                torch.zeros(STANDARD_CHUNK)
            ).shape[-1]:
                std_indices.append(ci)
            else:
                ext_indices.append(ci)

        encoded_flat: list[EncodedSegment | None] = [None] * len(flat_chunks)

        if std_indices:
            std_chunks = [flat_chunks[i] for i in std_indices]
            try:
                std_results = self._encode_batch(
                    std_chunks, xcodec_batch_size_override=xcodec_batch_size_override
                )
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                if "out of memory" in str(e).lower() or "illegal memory" in str(e).lower():
                    logger.warning(
                        "CUDA error on batch encode (chunks=%d), falling back to B=1",
                        len(std_chunks),
                    )
                    torch.cuda.empty_cache()
                    gc.collect()
                    std_results = [self._encode_one(c) for c in std_chunks]
                else:
                    raise
            for ci, enc in zip(std_indices, std_results):
                encoded_flat[ci] = enc

        for ci in ext_indices:
            encoded_flat[ci] = self._encode_one(flat_chunks[ci])

        failed_chunks = sum(1 for e in encoded_flat if e is None)
        if failed_chunks:
            logger.warning("%d/%d chunks failed to encode", failed_chunks, len(encoded_flat))
            if failed_chunks == len(encoded_flat):
                return []

        results: list[EncodedSegment] = []
        for seg_idx, (start, count, orig_seg) in enumerate(groups):
            group_enc = encoded_flat[start: start + count]
            group_vs = valid_samples[start: start + count]
            if any(e is None for e in group_enc):
                logger.warning("Skipping segment %d (%.1f-%.1fs): %d/%d chunks failed",
                              seg_idx, orig_seg.start_s, orig_seg.end_s,
                              sum(1 for e in group_enc if e is None), count)
                continue
            stitched = self._stitch_group(group_enc, group_vs, orig_seg, seg_idx)
            results.append(stitched)

        return results
