"""
Spark TTS BiCodec Decoder

Decodes BiCodec audio tokens (semantic + global) to audio waveforms.
Replaces SNAC decoder for Spark TTS model.
"""

import torch
import numpy as np
import sys
from pathlib import Path
from typing import List, Optional, Tuple

# Ensure `sparktts` (vendored under external/) is importable in local/dev runs.
# In Modal, PYTHONPATH already includes `/root/external/sparktts`.
try:
    # Fast path: the package is already importable (e.g. PYTHONPATH is set up).
    from sparktts.models.audio_tokenizer import BiCodecTokenizer
except ImportError:  # pragma: no cover - environment-dependent
    # Fallback: search upward from this file's directory for an
    # `external/sparktts` directory and put the first match on sys.path.
    _this_file = Path(__file__).resolve()
    _sparktts_path = None
    # Search order is nearest directory first. `_this_file.parent` is also the
    # first entry of `_this_file.parents`, so that directory is probed twice;
    # this is harmless because the loop stops at the first hit.
    for _p in [_this_file.parent] + list(_this_file.parents):
        _candidate = _p / "external" / "sparktts"
        if _candidate.is_dir():
            _sparktts_path = str(_candidate)
            break
    # Prepend the discovered directory so it is searched before other entries;
    # skipped when nothing was found or the path is already on sys.path.
    if _sparktts_path and _sparktts_path not in sys.path:
        sys.path.insert(0, _sparktts_path)
    # Retry the import. If the package is still missing, this ImportError
    # propagates to whoever imported this module.
    from sparktts.models.audio_tokenizer import BiCodecTokenizer

from veena3modal.core.constants import (
    BICODEC_TOKENIZER_PATH,
    AUDIO_SAMPLE_RATE,
)


class BiCodecDecoder:
    """
    BiCodec Decoder for Spark TTS.

    Decodes semantic and global tokens to audio waveforms using BiCodec.
    This replaces the SNAC decoder used in the old Indic Orpheus model.

    Supports streaming with sliding window approach (like SNAC).

    All decode entry points report failures by printing a message and
    returning ``None`` rather than raising.
    """

    # Number of global (speaker) tokens every decode call is normalised to.
    EXPECTED_GLOBAL_TOKENS = 32
    # Streaming decodes with fewer semantic tokens than this are rejected.
    MIN_SEMANTIC_TOKENS = 8
    # Rough samples-per-semantic-token figure used for duration estimates
    # (per the original author's note: ~16 kHz / 50 tokens per second).
    SAMPLES_PER_SEMANTIC_TOKEN = 320
    # Upper bound on the number of samples kept in sliding-window mode; the
    # window is only applied to waveforms longer than this.
    SLIDING_WINDOW_SAMPLES = 4096

    def __init__(
        self,
        device: str = "cuda",
        model_path: str = BICODEC_TOKENIZER_PATH,
        enable_batching: bool = False,  # For compatibility with streaming pipeline
    ):
        """
        Initialize BiCodec decoder.

        Args:
            device: Requested device for the BiCodec model (cuda/cpu). When
                CUDA is not available the decoder always falls back to CPU,
                whatever was requested.
            model_path: Path to BiCodec model checkpoint
            enable_batching: Stored for streaming pipeline compatibility;
                it does not change decoding behaviour in this class.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model_path = model_path
        self.enable_batching = enable_batching  # For streaming pipeline compatibility

        print(f"🎵 Loading BiCodec audio tokenizer from {model_path}...")
        print(f"   Device: {self.device}")

        # Initialize BiCodec tokenizer
        self.audio_tokenizer = BiCodecTokenizer(model_path, device=str(self.device))

        print(f"✅ BiCodec decoder initialized (sample rate: {AUDIO_SAMPLE_RATE}Hz)")
        if enable_batching:
            print("   ℹ️  Batching enabled (for streaming support)")

    @classmethod
    def _fit_global_ids(cls, global_ids: List[int]) -> Tuple[List[int], Optional[str]]:
        """
        Force a global-token list to exactly EXPECTED_GLOBAL_TOKENS entries.

        Args:
            global_ids: Non-empty sequence of global token IDs.

        Returns:
            ``(ids, action)`` where ``action`` is ``None`` when the input
            already had the right length, ``"Truncated"`` when extra tokens
            were dropped from the end, or ``"Padded"`` when zeros were
            appended. The input sequence is never mutated.
        """
        expected = cls.EXPECTED_GLOBAL_TOKENS
        count = len(global_ids)
        if count == expected:
            return global_ids, None
        if count > expected:
            return list(global_ids[:expected]), "Truncated"
        return list(global_ids) + [0] * (expected - count), "Padded"

    @staticmethod
    def _to_pcm16_bytes(audio: np.ndarray) -> bytes:
        """Clip a float waveform to [-1.0, 1.0] and serialise it as int16 PCM bytes."""
        clipped = np.clip(audio, -1.0, 1.0)
        return (clipped * 32767).astype(np.int16).tobytes()

    def decode(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
    ) -> Optional[np.ndarray]:
        """
        Decode BiCodec tokens to audio waveform.

        CRITICAL: BiCodec expects EXACTLY 32 global tokens!
        The decoder pools them via speaker_encoder → d_vector, then broadcasts
        via d_vector.unsqueeze(-1) across the time dimension.
        A non-empty ``global_ids`` of the wrong length is truncated or
        zero-padded to 32 (with a printed warning) instead of being rejected.

        Args:
            semantic_ids: List of semantic token IDs (variable length)
            global_ids: List of global token IDs (should be exactly 32)

        Returns:
            Whatever ``BiCodecTokenizer.detokenize`` returns — expected to be
            a 1-D float waveform (numpy array) at AUDIO_SAMPLE_RATE; confirm
            against the tokenizer implementation.
            Returns None if either token list is empty/None or decoding fails.
        """
        # Count defensively so a None argument takes the "empty" path instead
        # of raising TypeError inside the log message.
        semantic_count = len(semantic_ids) if semantic_ids else 0
        global_count = len(global_ids) if global_ids else 0
        if not semantic_count or not global_count:
            print(f"⚠️  Empty token lists: semantic={semantic_count}, global={global_count}")
            return None

        # Validate global token count and try to fix it when it is off.
        global_ids, action = self._fit_global_ids(global_ids)
        if action is not None:
            print(f"❌ BiCodec requires EXACTLY {self.EXPECTED_GLOBAL_TOKENS} global tokens, got {global_count}")
            print(f"   └─ {action} to {self.EXPECTED_GLOBAL_TOKENS}")

        try:
            # Convert to tensors following official Spark-TTS implementation
            # global_token_ids shape: (1, 32) - ALWAYS 32!
            # pred_semantic_ids shape: (1, N) - variable length
            pred_semantic = torch.tensor(semantic_ids).long().unsqueeze(0).to(self.device)
            pred_global = torch.tensor(global_ids).long().unsqueeze(0).to(self.device)

            # Ensure audio tokenizer is on correct device
            self.audio_tokenizer.device = self.device
            self.audio_tokenizer.model.to(self.device)

            # Decode through BiCodec
            # BiCodecTokenizer.detokenize() expects:
            #   - global_tokens: (batch, global_dim) -> will unsqueeze(1) internally
            #   - semantic_tokens: (batch, latent_dim)
            with torch.inference_mode():
                wav_np = self.audio_tokenizer.detokenize(
                    pred_global,  # (1, 32)
                    pred_semantic,  # (1, N)
                )

            return wav_np

        except Exception as e:
            # Deliberate best-effort boundary: callers expect None on any
            # decode failure, so report the error here instead of raising.
            print(f"❌ BiCodec decode error: {e}")
            import traceback
            traceback.print_exc()
            return None

    def decode_to_bytes(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
    ) -> Optional[bytes]:
        """
        Decode BiCodec tokens to audio bytes (int16 PCM).

        Args:
            semantic_ids: List of semantic token IDs
            global_ids: List of global token IDs

        Returns:
            Audio as int16 PCM bytes (waveform clipped to [-1, 1] and scaled
            by 32767). Returns None if decode fails.
        """
        audio = self.decode(semantic_ids, global_ids)
        if audio is None:
            return None
        return self._to_pcm16_bytes(audio)

    def validate_tokens(self, semantic_ids: List[int], global_ids: List[int]) -> bool:
        """
        Validate BiCodec tokens before decoding.

        Only checks that both lists are non-empty; token values and the
        global-token count are not inspected here.

        Args:
            semantic_ids: List of semantic token IDs
            global_ids: List of global token IDs

        Returns:
            True if valid, False otherwise
        """
        if not semantic_ids:
            print("❌ No semantic tokens")
            return False

        if not global_ids:
            print("❌ No global tokens")
            return False

        return True

    def get_audio_duration(self, semantic_ids: List[int]) -> float:
        """
        Estimate audio duration from semantic tokens.

        This is an approximation (SAMPLES_PER_SEMANTIC_TOKEN samples per
        token); the actual duration depends on the model output.

        Args:
            semantic_ids: List of semantic token IDs

        Returns:
            Estimated duration in seconds
        """
        estimated_samples = len(semantic_ids) * self.SAMPLES_PER_SEMANTIC_TOKEN
        return estimated_samples / AUDIO_SAMPLE_RATE

    def decode_streaming(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
        use_sliding_window: bool = False,
        trim_warmup: bool = False,
    ) -> Optional[bytes]:
        """
        Decode BiCodec tokens with streaming support.

        CRITICAL: BiCodec decoder expects EXACTLY 32 global tokens always!
        The decoder internally broadcasts d_vector via d_vector.unsqueeze(-1).

        From sparktts/models/bicodec.py line 184-186:
            d_vector = self.speaker_encoder.detokenize(global_tokens)  # Expects 32 tokens
            x = self.prenet(z_q, d_vector)
            x = x + d_vector.unsqueeze(-1)  # Broadcasts across time automatically!

        Args:
            semantic_ids: List of semantic token IDs (variable length, 50 TPS)
            global_ids: List of global token IDs (should be exactly 32). A
                non-empty list of another length is truncated/zero-padded
                with a warning; an empty or None list makes the call return
                None, exactly like decode(), instead of decoding with 32
                fabricated zero tokens.
            use_sliding_window: If True and the waveform is longer than
                SLIDING_WINDOW_SAMPLES, return only a centred slice of it.
            trim_warmup: Legacy parameter (not used for BiCodec)

        Returns:
            Audio bytes (int16 PCM) or None if decode fails
        """
        # Minimum semantic tokens needed for stable decode. None counts as 0
        # so it is rejected here instead of raising TypeError.
        semantic_count = len(semantic_ids) if semantic_ids else 0
        if semantic_count < self.MIN_SEMANTIC_TOKENS:
            print(f"⚠️  Too few semantic tokens ({semantic_count}) for decode, need >= {self.MIN_SEMANTIC_TOKENS}")
            return None

        # Ensure exactly 32 global tokens. An empty list is left untouched on
        # purpose so decode() can report and reject it.
        if global_ids:
            received = len(global_ids)
            global_ids, action = self._fit_global_ids(global_ids)
            if action is not None:
                print(f"⚠️  WARNING: Got {received} global tokens, expected {self.EXPECTED_GLOBAL_TOKENS}")
                print(f"   └─ {action} to {self.EXPECTED_GLOBAL_TOKENS} tokens")

        # Decode with EXACTLY 32 global tokens + variable semantic tokens
        # The decoder handles broadcasting internally!
        audio = self.decode(semantic_ids, global_ids)
        if audio is None:
            return None

        # If using sliding window, return only the middle portion
        if use_sliding_window and len(audio) > self.SLIDING_WINDOW_SAMPLES:
            # Keep a centred slice for overlap-add streaming: at most
            # SLIDING_WINDOW_SAMPLES samples and never more than half of
            # the decoded waveform.
            total_samples = len(audio)
            keep_samples = min(self.SLIDING_WINDOW_SAMPLES, total_samples // 2)
            start = (total_samples - keep_samples) // 2
            audio = audio[start:start + keep_samples]

        return self._to_pcm16_bytes(audio)

    async def decode_single_async(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
        trim_warmup: bool = False,
        use_sliding_window: bool = False,
    ) -> Optional[bytes]:
        """
        Async wrapper for streaming decode (for compatibility with streaming pipeline).

        The decode itself runs synchronously inside the coroutine, so it
        blocks the event loop for the duration of the call.

        Args:
            semantic_ids: List of semantic token IDs
            global_ids: List of global token IDs
            trim_warmup: Legacy parameter (not used)
            use_sliding_window: Use sliding window mode

        Returns:
            Audio bytes (int16 PCM) or None
        """
        # BiCodec decode is fast enough to run synchronously
        return self.decode_streaming(
            semantic_ids=semantic_ids,
            global_ids=global_ids,
            use_sliding_window=use_sliding_window,
            trim_warmup=trim_warmup,
        )


