"""
Spark TTS BiCodec Decoder

Decodes BiCodec audio tokens (semantic + global) to audio waveforms.
Replaces SNAC decoder for Spark TTS model.
"""

import torch
import numpy as np
import sys
from pathlib import Path
from typing import List, Optional, Tuple

# Ensure `sparktts` (vendored under external/) is importable in local/dev runs.
# In Modal, PYTHONPATH already includes `/root/external/sparktts`.
try:
    from sparktts.models.audio_tokenizer import BiCodecTokenizer
except ImportError:  # pragma: no cover - environment-dependent
    _this_file = Path(__file__).resolve()
    _sparktts_path = None
    for _p in [_this_file.parent] + list(_this_file.parents):
        _candidate = _p / "external" / "sparktts"
        if _candidate.is_dir():
            _sparktts_path = str(_candidate)
            break
    if _sparktts_path and _sparktts_path not in sys.path:
        sys.path.insert(0, _sparktts_path)
    from sparktts.models.audio_tokenizer import BiCodecTokenizer

from veena3modal.core.constants import (
    BICODEC_TOKENIZER_PATH,
    AUDIO_SAMPLE_RATE,
)


class BiCodecDecoder:
    """
    BiCodec Decoder for Spark TTS.

    Decodes semantic and global tokens to audio waveforms using BiCodec.
    This replaces the SNAC decoder used in the old Indic Orpheus model.

    Streaming callers can ask ``decode_streaming`` for a centered
    "sliding window" slice of each decoded chunk instead of the whole chunk.
    """

    # BiCodec's speaker encoder is fed a fixed-size block of global tokens;
    # inputs of any other length are truncated / zero-padded to this size.
    EXPECTED_GLOBAL_TOKENS = 32

    # decode_streaming() refuses chunks with fewer semantic tokens than this.
    MIN_SEMANTIC_TOKENS = 8

    # Rough samples per semantic token used only for duration estimates
    # (16 kHz / 50 tokens-per-second, per the notes in this class).
    SAMPLES_PER_SEMANTIC_TOKEN = 320

    # Upper bound on the centered slice kept in sliding-window streaming mode.
    SLIDING_WINDOW_SAMPLES = 4096

    def __init__(
        self,
        device: str = "cuda",
        model_path: str = BICODEC_TOKENIZER_PATH,
        enable_batching: bool = False,  # For compatibility with streaming pipeline
    ):
        """
        Initialize BiCodec decoder.

        Args:
            device: Requested device for the BiCodec model. Falls back to
                "cpu" whenever CUDA is not available.
            model_path: Path to BiCodec model checkpoint
            enable_batching: Stored for streaming-pipeline compatibility;
                not otherwise used by this class.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model_path = model_path
        self.enable_batching = enable_batching  # For streaming pipeline compatibility

        print(f"🎵 Loading BiCodec audio tokenizer from {model_path}...")
        print(f"   Device: {self.device}")

        self.audio_tokenizer = BiCodecTokenizer(model_path, device=str(self.device))

        # Move the model once here so decode() never pays per-call device migration.
        self.audio_tokenizer.device = self.device
        self.audio_tokenizer.model.to(self.device)

        # NOTE: torch.compile disabled for BiCodec -- dynamic semantic sequence lengths
        # cause repeated recompilation (~120s each), worse than eager mode.
        # torch.compile works best with static shapes. If needed, pre-warm with
        # representative lengths or use mode="default" instead of "reduce-overhead".
        # The multi-engine approach (Tier 3) provides more impactful gains.

        print(f"✅ BiCodec decoder initialized (sample rate: {AUDIO_SAMPLE_RATE}Hz)")
        if enable_batching:
            print("   ℹ️  Batching enabled (for streaming support)")

    @classmethod
    def _fit_global_ids(
        cls,
        global_ids: List[int],
        header_template: str,
        unit_suffix: str = "",
    ) -> List[int]:
        """
        Return ``global_ids`` forced to exactly ``EXPECTED_GLOBAL_TOKENS`` entries.

        Longer inputs are truncated; shorter inputs are right-padded with 0.
        Inputs that already have the right length are returned unchanged.

        Args:
            global_ids: Global token IDs to fit.
            header_template: Message printed when a fix-up is needed; may use
                ``{expected}`` and ``{count}`` placeholders.
            unit_suffix: Text appended to the "Truncated/Padded to N" note.

        Returns:
            A token list of length ``EXPECTED_GLOBAL_TOKENS``.
        """
        expected = cls.EXPECTED_GLOBAL_TOKENS
        count = len(global_ids)
        if count == expected:
            return global_ids

        print(header_template.format(expected=expected, count=count))
        if count > expected:
            fitted = list(global_ids[:expected])
            action = "Truncated"
        else:
            # Padding with 0 only lets decode proceed; the padded IDs are
            # placeholders, not real speaker tokens.
            fitted = list(global_ids) + [0] * (expected - count)
            action = "Padded"
        print(f"   └─ {action} to {expected}{unit_suffix}")
        return fitted

    @staticmethod
    def _to_pcm16_bytes(audio: np.ndarray) -> bytes:
        """
        Convert a float waveform nominally in [-1, 1] to int16 PCM bytes.

        NaN samples become silence and +/-inf saturate to full scale, so a
        numerically unstable decode cannot produce undefined int16 values.
        Output uses numpy's native byte order.
        """
        audio = np.nan_to_num(audio, nan=0.0, posinf=1.0, neginf=-1.0)
        audio = np.clip(audio, -1.0, 1.0)
        # Scale by 32767 so +1.0 maps to the int16 maximum; astype truncates toward zero.
        return (audio * 32767).astype(np.int16).tobytes()

    def decode(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
    ) -> Optional[np.ndarray]:
        """
        Decode BiCodec tokens to audio waveform.

        CRITICAL: BiCodec expects EXACTLY 32 global tokens!
        The decoder pools them via speaker_encoder → d_vector, then broadcasts
        via d_vector.unsqueeze(-1) across the time dimension. Global token
        lists of any other length are truncated / zero-padded to 32 (with a
        printed warning) rather than rejected.

        Args:
            semantic_ids: List of semantic token IDs (variable length)
            global_ids: List of global token IDs (should be exactly 32)

        Returns:
            Whatever ``BiCodecTokenizer.detokenize`` returns — expected to be a
            float32 16 kHz mono waveform of shape (samples,).
            Returns None for empty/None inputs or if decoding raises.
        """
        if not semantic_ids or not global_ids:
            # `x or []` keeps this diagnostic from raising when an input is None.
            print(
                f"⚠️  Empty token lists: semantic={len(semantic_ids or [])}, "
                f"global={len(global_ids or [])}"
            )
            return None

        global_ids = self._fit_global_ids(
            global_ids,
            "❌ BiCodec requires EXACTLY {expected} global tokens, got {count}",
        )

        try:
            # Shapes follow the official Spark-TTS implementation:
            #   global tokens:   (1, 32)
            #   semantic tokens: (1, N), N variable
            # Tensors are created directly on the target device; the model
            # itself was moved there once in __init__.
            pred_semantic = torch.tensor(
                semantic_ids, dtype=torch.long, device=self.device
            ).unsqueeze(0)
            pred_global = torch.tensor(
                global_ids, dtype=torch.long, device=self.device
            ).unsqueeze(0)

            with torch.inference_mode():
                wav_np = self.audio_tokenizer.detokenize(pred_global, pred_semantic)

            return wav_np

        except Exception as e:
            # Best-effort: a failed chunk yields None so streaming can continue.
            print(f"❌ BiCodec decode error: {e}")
            import traceback
            traceback.print_exc()
            return None

    def decode_to_bytes(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
    ) -> Optional[bytes]:
        """
        Decode BiCodec tokens to audio bytes (int16 PCM).

        Args:
            semantic_ids: List of semantic token IDs
            global_ids: List of global token IDs

        Returns:
            Audio as int16 PCM bytes, or None if decode fails.
        """
        audio = self.decode(semantic_ids, global_ids)
        if audio is None:
            return None
        return self._to_pcm16_bytes(audio)

    def validate_tokens(self, semantic_ids: List[int], global_ids: List[int]) -> bool:
        """
        Validate BiCodec tokens before decoding.

        Only checks that both lists are non-empty; the exact global-token
        count is not enforced here (decode() fits it to 32 itself).

        Args:
            semantic_ids: List of semantic token IDs
            global_ids: List of global token IDs

        Returns:
            True if both lists are non-empty, False otherwise.
        """
        if not semantic_ids:
            print("❌ No semantic tokens")
            return False

        if not global_ids:
            print("❌ No global tokens")
            return False

        return True

    def get_audio_duration(self, semantic_ids: List[int]) -> float:
        """
        Estimate audio duration from semantic tokens.

        This is a rough estimate (SAMPLES_PER_SEMANTIC_TOKEN samples per
        token); the actual duration depends on model output.

        Args:
            semantic_ids: List of semantic token IDs

        Returns:
            Estimated duration in seconds
        """
        return len(semantic_ids) * self.SAMPLES_PER_SEMANTIC_TOKEN / AUDIO_SAMPLE_RATE

    def decode_streaming(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
        use_sliding_window: bool = False,
        trim_warmup: bool = False,
    ) -> Optional[bytes]:
        """
        Decode BiCodec tokens with streaming support.

        CRITICAL: BiCodec decoder expects EXACTLY 32 global tokens always!
        The decoder internally broadcasts d_vector via d_vector.unsqueeze(-1).

        From sparktts/models/bicodec.py line 184-186:
            d_vector = self.speaker_encoder.detokenize(global_tokens)  # Expects 32 tokens
            x = self.prenet(z_q, d_vector)
            x = x + d_vector.unsqueeze(-1)  # Broadcasts across time automatically!

        Args:
            semantic_ids: List of semantic token IDs (variable length, 50 TPS)
            global_ids: List of global token IDs (should be exactly 32 tokens;
                other lengths are truncated / zero-padded with a warning)
            use_sliding_window: If True and the chunk is longer than
                SLIDING_WINDOW_SAMPLES, return only a centered slice
            trim_warmup: Legacy parameter (not used for BiCodec)

        Returns:
            Audio bytes (int16 PCM), or None if there are fewer than
            MIN_SEMANTIC_TOKENS semantic tokens or decode fails.
        """
        if len(semantic_ids) < self.MIN_SEMANTIC_TOKENS:
            print(
                f"⚠️  Too few semantic tokens ({len(semantic_ids)}) for decode, "
                f"need >= {self.MIN_SEMANTIC_TOKENS}"
            )
            return None

        global_ids = self._fit_global_ids(
            global_ids,
            "⚠️  WARNING: Got {count} global tokens, expected {expected}",
            unit_suffix=" tokens",
        )

        audio = self.decode(semantic_ids, global_ids)
        if audio is None:
            return None

        if use_sliding_window and len(audio) > self.SLIDING_WINDOW_SAMPLES:
            # Keep a centered slice for overlap-add streaming: at most
            # SLIDING_WINDOW_SAMPLES and never more than half the chunk.
            total_samples = len(audio)
            keep_samples = min(self.SLIDING_WINDOW_SAMPLES, total_samples // 2)
            start = (total_samples - keep_samples) // 2
            audio = audio[start:start + keep_samples]

        return self._to_pcm16_bytes(audio)

    async def decode_single_async(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
        trim_warmup: bool = False,
        use_sliding_window: bool = False,
    ) -> Optional[bytes]:
        """
        Async wrapper for streaming decode that runs it in the default executor.

        Offloading keeps the event loop free while a decode runs, so other
        coroutines can progress. Because the default thread pool is used,
        concurrent calls may run decode_streaming() on several worker
        threads at once.

        Args:
            semantic_ids: List of semantic token IDs
            global_ids: List of global token IDs
            trim_warmup: Legacy parameter (not used)
            use_sliding_window: Use sliding window mode

        Returns:
            Audio bytes (int16 PCM) or None
        """
        import asyncio

        # get_running_loop() is the documented way to get the loop from a coroutine.
        loop = asyncio.get_running_loop()
        # Note: decode_streaming takes (use_sliding_window, trim_warmup) in the
        # opposite order from this method's signature.
        return await loop.run_in_executor(
            None,  # Default ThreadPoolExecutor
            self.decode_streaming,
            semantic_ids,
            global_ids,
            use_sliding_window,
            trim_warmup,
        )


