"""
Spark TTS Model Loader

Loads Spark TTS model with vLLM engine and validates speakers/emotions.
Migrated from Veena3/Orpheus to Spark TTS architecture.
"""

import os
import torch
from typing import Optional, Dict, Any
from transformers import AutoTokenizer
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams

from veena3modal.core.constants import (
    INDIC_EMOTION_TAGS,
    INDIC_SPEAKERS,
    SPEAKER_MAP,
    DEFAULT_MODEL_PATH,
    DEFAULT_MAX_MODEL_LEN,
    VLLM_CONFIG,
)


class SparkTTSModel:
    """
    Spark TTS Model with vLLM inference engine and BiCodec audio tokenizer.

    Model: BayAreaBoys/spark_tts_4speaker
    Architecture: Qwen2ForCausalLM with BiCodec audio tokenization
    Speakers: user-facing names from INDIC_SPEAKERS, mapped through SPEAKER_MAP
        to ``<|speaker_{id}|>`` tokens
    Emotions: tags from INDIC_EMOTION_TAGS in [bracket] format

    CRITICAL REQUIREMENTS:
    - Load tokenizer with special Spark TTS tokens
    - Configure vLLM with optimizations (prefix caching, CUDA graphs)
    - Support speaker mapping (lipakshi → speaker_0, etc.)
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        model_type: str = "indic_speakers",  # Default to indic_speakers for backward compatibility
        hf_token: Optional[str] = None,
        **engine_kwargs: Any,
    ) -> None:
        """
        Initialize Spark TTS model with vLLM.

        Loads the tokenizer, registers emotion tags as special tokens,
        validates the speaker mapping and emotion tags, then builds an
        ``AsyncLLMEngine``.

        Args:
            model_path: Path to model directory (local or HF). If None, uses the
                SPARK_TTS_MODEL_PATH env var, falling back to DEFAULT_MODEL_PATH.
            model_type: Ignored; kept for API compatibility. Spark TTS is always
                "indic_speakers".
            hf_token: HuggingFace token for private models. Passed to the
                tokenizer loader and exported to the process environment
                (HF_TOKEN and HUGGING_FACE_HUB_TOKEN) for vLLM's downloads.
            **engine_kwargs: Additional vLLM engine arguments. They override both
                VLLM_CONFIG and the model/tokenizer paths.
        """
        # Spark TTS only supports speaker-based generation; ``model_type`` is
        # accepted solely so existing callers keep working.
        self.model_type = "indic_speakers"

        # Precedence: explicit argument > SPARK_TTS_MODEL_PATH > package default.
        if model_path is None:
            model_path = os.environ.get('SPARK_TTS_MODEL_PATH', DEFAULT_MODEL_PATH)

        self.model_path = model_path
        self.hf_token = hf_token

        # Derive the banner from the constants so it cannot drift from them.
        speaker_names = ", ".join(str(name) for name in INDIC_SPEAKERS)
        print("🚀 Initializing Spark TTS Model")
        print(f"📁 Model path: {model_path}")
        print(f"🗣️  Speakers: {len(INDIC_SPEAKERS)} predefined ({speaker_names})")
        print(f"🎭 Emotions: {len(INDIC_EMOTION_TAGS)} tags in [bracket] format")

        self.tokenizer = self._load_tokenizer(model_path, hf_token)

        # Validate speakers and emotions against the loaded tokenizer.
        self._validate_speakers()
        self._validate_emotion_tags()

        print("\n🔧 Initializing vLLM engine with optimizations...")
        engine_kwargs_dict = self._build_engine_kwargs(model_path, hf_token, engine_kwargs)
        self._print_engine_config(engine_kwargs_dict)

        engine_args = AsyncEngineArgs(**engine_kwargs_dict)
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        print("✅ vLLM engine initialized successfully!")
        print("\n🎉 Spark TTS Model ready for inference!\n")

    @staticmethod
    def _load_tokenizer(model_path: str, hf_token: Optional[str]):
        """
        Load the tokenizer for ``model_path`` and register emotion tags.

        Args:
            model_path: Local directory or HF repo id to load the tokenizer from.
            hf_token: Optional HF token, forwarded as ``token=`` for private repos.

        Returns:
            The loaded tokenizer with INDIC_EMOTION_TAGS registered as
            additional special tokens.
        """
        print("\n📝 Loading tokenizer from model...")
        tokenizer_kwargs: Dict[str, Any] = {'trust_remote_code': True}
        if hf_token:
            # Required for private repos.
            tokenizer_kwargs['token'] = hf_token

        tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)

        # Register emotion tags as special tokens so each tag stays intact.
        # NOTE(review): with transformers' default
        # replace_additional_special_tokens=True, this replaces (not extends)
        # any existing additional_special_tokens list. Confirm this is intended.
        num_added = tokenizer.add_special_tokens(
            {'additional_special_tokens': INDIC_EMOTION_TAGS}
        )
        if num_added:
            # Newly created ids were not in the saved tokenizer. The vLLM engine
            # reloads its own tokenizer from model_path and will not see them,
            # and they may lie outside the checkpoint's embedding table.
            print(f"⚠️  Warning: {num_added} emotion tag(s) were not in the saved "
                  f"tokenizer vocabulary and were added as new tokens")

        print(f"✅ Tokenizer loaded: {len(tokenizer)} tokens")
        return tokenizer

    @staticmethod
    def _build_engine_kwargs(
        model_path: str,
        hf_token: Optional[str],
        overrides: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Merge VLLM_CONFIG defaults, the model/tokenizer path, and caller overrides.

        Later sources win: VLLM_CONFIG < model/tokenizer path < ``overrides``.
        Has the side effect of exporting ``hf_token`` to the process environment
        when one is given.

        Returns:
            A new dict suitable for ``AsyncEngineArgs(**...)``. VLLM_CONFIG
            itself is not mutated.
        """
        engine_kwargs_dict = VLLM_CONFIG.copy()
        engine_kwargs_dict.update({
            'model': model_path,
            'tokenizer': model_path,
        })

        if hf_token:
            # vLLM downloads through huggingface_hub, which reads the token from
            # the environment and prefers HF_TOKEN over the legacy
            # HUGGING_FACE_HUB_TOKEN. Set both so the caller's token is not
            # shadowed by a pre-existing HF_TOKEN.
            os.environ['HUGGING_FACE_HUB_TOKEN'] = hf_token
            os.environ['HF_TOKEN'] = hf_token

        # User-provided kwargs have the final say.
        engine_kwargs_dict.update(overrides)
        return engine_kwargs_dict

    @staticmethod
    def _print_engine_config(engine_kwargs_dict: Dict[str, Any]) -> None:
        """
        Print the key engine settings.

        Keys not present in ``engine_kwargs_dict`` are reported as
        "<vLLM default>" rather than a guessed value, and do not raise.
        """
        unset = "<vLLM default>"
        print(f"  - dtype: {engine_kwargs_dict.get('dtype', unset)}")
        print(f"  - max_model_len: {engine_kwargs_dict.get('max_model_len', unset)}")
        print(f"  - gpu_memory_utilization: "
              f"{engine_kwargs_dict.get('gpu_memory_utilization', unset)}")
        print(f"  - enable_prefix_caching: "
              f"{engine_kwargs_dict.get('enable_prefix_caching', unset)}")
        print(f"  - enforce_eager (disable CUDA graphs): "
              f"{engine_kwargs_dict.get('enforce_eager', unset)}")

    def _validate_emotion_tags(self) -> None:
        """
        Validate that emotion tags are properly configured for Spark TTS.

        Each tag in INDIC_EMOTION_TAGS is encoded without BOS/EOS. A tag that
        yields no token ids is reported as a warning. Multi-token encodings are
        allowed. Only prints; never raises.
        """
        total = len(INDIC_EMOTION_TAGS)
        print(f"\n🔍 Validating {total} emotion tags for Spark TTS...")

        # Encode each tag once; reused for both the check and the sample listing.
        encoded = [
            (tag, self.tokenizer.encode(tag, add_special_tokens=False))
            for tag in INDIC_EMOTION_TAGS
        ]

        empty_tags = [tag for tag, token_ids in encoded if not token_ids]
        for tag in empty_tags:
            print(f"⚠️  Warning: Emotion tag {tag} produced no tokens")

        if empty_tags:
            print(f"⚠️  {total - len(empty_tags)}/{total} emotion tags validated "
                  f"({len(empty_tags)} produced no tokens)")
        else:
            print(f"✅ All {total} emotion tags validated")

        # Show the token count for the first few tags as a quick sanity check.
        sample_size = 5
        print("\n📋 Sample emotion tags:")
        for tag, token_ids in encoded[:sample_size]:
            print(f"  {tag}: {len(token_ids)} tokens")
        if len(encoded) > sample_size:
            print("  ...")

    def _validate_speakers(self) -> None:
        """
        Validate that all speakers are properly mapped for Spark TTS.

        Each user-facing name in INDIC_SPEAKERS (lipakshi, vardan, etc.) is
        looked up in SPEAKER_MAP and rendered as a ``<|speaker_{id}|>`` token.
        Names missing from SPEAKER_MAP are reported. For mapped names, the
        number of ids the token encodes to is printed. Only prints; never raises.
        """
        total = len(INDIC_SPEAKERS)
        print(f"\n🔍 Validating {total} speakers for Spark TTS...")

        missing = []
        for speaker_name in INDIC_SPEAKERS:
            if speaker_name not in SPEAKER_MAP:
                print(f"⚠️  Warning: Speaker {speaker_name} not in SPEAKER_MAP")
                missing.append(speaker_name)
                continue

            speaker_token = f'<|speaker_{SPEAKER_MAP[speaker_name]}|>'
            token_ids = self.tokenizer.encode(speaker_token, add_special_tokens=False)
            print(f"  {speaker_name} → {speaker_token}: {len(token_ids)} tokens")

        if missing:
            missing_names = ", ".join(str(name) for name in missing)
            print(f"⚠️  {total - len(missing)}/{total} speakers mapped; "
                  f"missing from SPEAKER_MAP: {missing_names}")
        else:
            print(f"✅ All {total} speakers validated with mapping")

    def get_model_type(self) -> str:
        """Get the model type (always 'indic_speakers' for Spark TTS)."""
        return self.model_type

    def get_tokenizer(self):
        """Get the tokenizer loaded in ``__init__``."""
        return self.tokenizer

    def get_engine(self):
        """Get the vLLM ``AsyncLLMEngine`` built in ``__init__``."""
        return self.engine

