"""
TTS runtime wrapper (container-scoped singleton).

In Modal, each container should load the model once and then handle many requests.
This wrapper owns:
- SparkTTSModel (vLLM engine)
- IndicPromptBuilder
- BiCodecDecoder
- SparkTTSPipeline + Veena3SlidingWindowPipeline
- optional SuperResolutionService

Imports framework-agnostic inference + processing code from `veena3modal/core`,
`veena3modal/processing`, and `veena3modal/audio`.
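
Example (illustrative Modal wiring; the app/class names are hypothetical,
the decorators are Modal's public API):

    app = modal.App("veena3-tts")

    @app.cls(gpu="A100")
    class TTSService:
        @modal.enter()
        def setup(self):
            initialize_runtime(enable_sr=True)

        @modal.method()
        async def synthesize(self, text: str, speaker: str) -> bytes:
            audio_bytes, _metrics = await generate_speech(text, speaker)
            return audio_bytes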
"""

from __future__ import annotations

import os
import time
import logging
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


@dataclass
class TTSRuntime:
    """
    Holds long-lived, per-container inference objects.
    
    Thread Safety:
        This class is designed for async concurrency within a single container.
        The vLLM engine handles internal batching and scheduling.
        Do NOT share across processes without proper synchronization.
    """
    model: Any = None
    pipeline: Any = None
    streaming_pipeline: Any = None
    prompt_builder: Any = None
    bicodec_decoder: Any = None
    sr_service: Optional[Any] = None
    
    model_version: str = "not_loaded"
    is_loaded: bool = False
    load_time_ms: float = 0.0
    
    # Configuration
    model_path: str = ""
    bicodec_path: str = ""
    sr_checkpoint_dir: Optional[str] = None
    device: str = "cuda"


# Module-level singleton (per-container)
_runtime: Optional[TTSRuntime] = None


def get_runtime() -> Optional[TTSRuntime]:
    """Get the current TTS runtime singleton."""
    return _runtime


def is_initialized() -> bool:
    """Check if the runtime is initialized and ready."""
    return _runtime is not None and _runtime.is_loaded


def initialize_runtime(
    model_path: Optional[str] = None,
    bicodec_path: Optional[str] = None,
    sr_checkpoint_dir: Optional[str] = None,
    device: str = "cuda",
    hf_token: Optional[str] = None,
    gpu_memory_utilization: float = 0.85,
    enable_sr: bool = False,
) -> TTSRuntime:
    """
    Initialize the TTS runtime with all components.
    
    This should be called once per container (e.g., in Modal's @modal.enter).
    
    Args:
        model_path: Path to Spark TTS model (env: SPARK_TTS_MODEL_PATH)
        bicodec_path: Path to BiCodec model (env: BICODEC_MODEL_PATH, defaults to model_path)
        sr_checkpoint_dir: Path to super-resolution checkpoints (env: AP_BWE_CHECKPOINT_DIR)
        device: Device for inference (cuda/cpu)
        hf_token: HuggingFace token for private models (env: HF_TOKEN)
        gpu_memory_utilization: vLLM GPU memory fraction (default: 0.85)
        enable_sr: Enable super-resolution service
    
    Returns:
        Initialized TTSRuntime instance
    
    Raises:
        RuntimeError: If initialization fails
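    
    Example (illustrative; relies on the env-var fallbacks listed above):
        os.environ["SPARK_TTS_MODEL_PATH"] = "/models/spark_tts_4speaker"
        runtime = initialize_runtime(enable_sr=False)
        print(runtime.model_version, f"{runtime.load_time_ms:.0f}ms")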
    """
    global _runtime
    
    start_time = time.time()
    
    # Resolve paths from env vars if not provided
    model_path = model_path or os.environ.get(
        'SPARK_TTS_MODEL_PATH',
        os.environ.get('MODEL_PATH', '/models/spark_tts_4speaker')
    )
    bicodec_path = bicodec_path or os.environ.get(
        'BICODEC_MODEL_PATH',
        model_path  # BiCodec is usually in the same directory
    )
    sr_checkpoint_dir = sr_checkpoint_dir or os.environ.get('AP_BWE_CHECKPOINT_DIR')
    hf_token = hf_token or os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
    
    logger.info(f"Initializing TTS runtime...")
    logger.info(f"  Model path: {model_path}")
    logger.info(f"  BiCodec path: {bicodec_path}")
    logger.info(f"  SR enabled: {enable_sr}, path: {sr_checkpoint_dir}")
    
    try:
        # Import inference components (vendored into veena3modal for Modal deployment)
        from veena3modal.core.model_loader import SparkTTSModel
        from veena3modal.core.pipeline import SparkTTSPipeline
        from veena3modal.core.bicodec_decoder import BiCodecDecoder
        from veena3modal.processing.prompt_builder import IndicPromptBuilder
        
        # Load SparkTTS model with vLLM engine
        logger.info("Loading SparkTTS model with vLLM...")
        model = SparkTTSModel(
            model_path=model_path,
            hf_token=hf_token,
            gpu_memory_utilization=gpu_memory_utilization,
        )
        
        # Initialize prompt builder
        logger.info("Initializing prompt builder...")
        prompt_builder = IndicPromptBuilder(
            tokenizer=model.tokenizer,
            model=model,
        )
        
        # Initialize BiCodec decoder
        logger.info("Initializing BiCodec decoder...")
        bicodec_decoder = BiCodecDecoder(
            device=device,
            model_path=bicodec_path,
        )
        
        # Initialize pipeline
        logger.info("Initializing TTS pipeline...")
        pipeline = SparkTTSPipeline(
            model=model,
            prompt_builder=prompt_builder,
            bicodec_decoder=bicodec_decoder,
        )
        
        # Initialize streaming pipeline (for M4)
        streaming_pipeline = None
        try:
            from veena3modal.core.streaming_pipeline import Veena3SlidingWindowPipeline
            logger.info("Initializing streaming pipeline...")
            # NOTE: Parameter is named 'snac_decoder' for legacy reasons, but works with BiCodecDecoder
            streaming_pipeline = Veena3SlidingWindowPipeline(
                model=model,
                prompt_builder=prompt_builder,
                snac_decoder=bicodec_decoder,  # BiCodecDecoder is interface-compatible
            )
        except ImportError as e:
            logger.warning(f"Streaming pipeline not available: {e}")
        
        # Initialize super-resolution (optional)
        sr_service = None
        if enable_sr and sr_checkpoint_dir:
            try:
                from veena3modal.core.super_resolution import SuperResolutionService
                logger.info(f"Initializing super-resolution from {sr_checkpoint_dir}...")
                sr_service = SuperResolutionService(checkpoint_dir=sr_checkpoint_dir)
                # Load the model explicitly
                if sr_service.load_model(device=device):
                    logger.info("✅ Super-resolution model loaded successfully")
                else:
                    logger.warning("Super-resolution model failed to load")
                    sr_service = None
            except Exception as e:
                logger.warning(f"Super-resolution not available: {e}")
        
        # Determine model version
        model_version = os.path.basename(model_path.rstrip('/'))
        if not model_version:
            model_version = "spark-tts"
        
        load_time = (time.time() - start_time) * 1000
        
        # Create runtime
        _runtime = TTSRuntime(
            model=model,
            pipeline=pipeline,
            streaming_pipeline=streaming_pipeline,
            prompt_builder=prompt_builder,
            bicodec_decoder=bicodec_decoder,
            sr_service=sr_service,
            model_version=model_version,
            is_loaded=True,
            load_time_ms=load_time,
            model_path=model_path,
            bicodec_path=bicodec_path,
            sr_checkpoint_dir=sr_checkpoint_dir,
            device=device,
        )
        
        logger.info(f"✅ TTS runtime initialized in {load_time:.0f}ms")
        logger.info(f"   Model version: {model_version}")
        
        # Update FastAPI app with model version
        try:
            from veena3modal.api.fastapi_app import set_model_version
            set_model_version(model_version)
        except ImportError:
            pass
        
        return _runtime
        
    except Exception as e:
        logger.error(f"❌ Failed to initialize TTS runtime: {e}")
        import traceback
        traceback.print_exc()
        raise RuntimeError(f"TTS runtime initialization failed: {e}") from e


async def generate_speech(
    text: str,
    speaker: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
    output_sample_rate: str = "16khz",
) -> Tuple[Optional[bytes], Dict[str, Any]]:
    """
    Generate speech audio (non-streaming).
    
    Args:
        text: Text to synthesize (already normalized)
        speaker: Internal speaker name (resolved)
        temperature: Sampling temperature
        top_k: Top-k sampling
        top_p: Nucleus sampling
        max_tokens: Maximum tokens to generate
        repetition_penalty: Repetition penalty
        seed: Random seed for reproducibility
        output_sample_rate: "16khz" or "48khz" (triggers super-resolution)
    
    Returns:
        Tuple of (audio_bytes, metrics_dict)
        audio_bytes: WAV audio bytes (16kHz or 48kHz, 16-bit PCM) or None if failed
        metrics_dict: Dictionary with timing metrics
    
    Raises:
        RuntimeError: If runtime not initialized
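    
    Example (illustrative; the speaker name is a placeholder):
        audio, metrics = await generate_speech(
            "नमस्ते", speaker="speaker_1", output_sample_rate="48khz"
        )
        # metrics["sr_applied"] is True only if the SR service was loaded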
    """
    if not is_initialized():
        raise RuntimeError("TTS runtime not initialized")
    
    runtime = get_runtime()
    metrics = {
        "ttfb_ms": 0,
        "generation_ms": 0,
        "tokens_generated": 0,
        "audio_duration_seconds": 0.0,
        "sr_applied": False,
        "output_sample_rate": 16000,
    }
    
    start_time = time.time()
    
    try:
        # Generate audio using pipeline (always at 16kHz)
        audio_bytes = await runtime.pipeline.generate_speech_indic(
            speaker=speaker,
            text=text,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        )
        
        generation_time = time.time() - start_time
        metrics["generation_ms"] = int(generation_time * 1000)
        metrics["ttfb_ms"] = metrics["generation_ms"]  # Non-streaming: TTFB = total time
        
        if audio_bytes:
            # Apply super-resolution if requested and available
            sample_rate = 16000
            logger.info(f"SR check: output_sample_rate={output_sample_rate!r}, "
                       f"sr_service={runtime.sr_service is not None}, "
                       f"is_loaded={runtime.sr_service.is_loaded if runtime.sr_service else False}")
            
            if output_sample_rate == "48khz" and runtime.sr_service and runtime.sr_service.is_loaded:
                try:
                    sr_start = time.time()
                    logger.info(f"Applying super-resolution to {len(audio_bytes)} bytes...")
                    audio_bytes = _apply_super_resolution(audio_bytes, runtime.sr_service)
                    sr_time = (time.time() - sr_start) * 1000
                    metrics["sr_ms"] = int(sr_time)
                    metrics["sr_applied"] = True
                    sample_rate = 48000
                    logger.info(f"✅ Super-resolution applied in {sr_time:.1f}ms, output={len(audio_bytes)} bytes")
                except Exception as e:
                    logger.warning(f"Super-resolution failed, returning 16kHz: {e}")
                    import traceback
                    traceback.print_exc()
            
            metrics["output_sample_rate"] = sample_rate
            
            # Calculate audio duration
            audio_duration = (len(audio_bytes) - 44) / (sample_rate * 2)  # -44 for WAV header
            metrics["audio_duration_seconds"] = max(0.0, audio_duration)
        
        return audio_bytes, metrics
        
    except Exception as e:
        logger.error(f"Speech generation failed: {e}")
        raise


def _apply_super_resolution(audio_bytes: bytes, sr_service) -> bytes:
    """
    Apply super-resolution to audio bytes.
    
    Args:
        audio_bytes: WAV audio at 16kHz
        sr_service: SuperResolutionService instance
    
    Returns:
        WAV audio at 48kHz
    """
    import numpy as np
    import struct
    import torch
    
    # The pipeline emits a canonical 44-byte WAV header; validate minimum size
    if len(audio_bytes) < 44:
        raise ValueError("Invalid WAV data")
    
    # Extract PCM data (skip 44-byte WAV header)
    pcm_data = np.frombuffer(audio_bytes[44:], dtype=np.int16)
    
    # Convert to float32 for SR model
    audio_float = pcm_data.astype(np.float32) / 32768.0
    
    # Convert numpy to torch tensor for process_chunk
    audio_tensor = torch.from_numpy(audio_float)
    
    # Apply super-resolution using process_chunk (16kHz -> 48kHz)
    # process_chunk expects [batch, samples] or [samples] and returns [batch, samples]
    upsampled_tensor = sr_service.process_chunk(audio_tensor)
    
    # Convert back to numpy, squeeze batch dim if present
    upsampled = upsampled_tensor.squeeze().cpu().numpy()
    
    # Convert back to int16
    upsampled_int16 = np.clip(upsampled * 32768.0, -32768, 32767).astype(np.int16)
    
    # Create new WAV header for 48kHz
    sample_rate = 48000
    num_channels = 1
    bits_per_sample = 16
    byte_rate = sample_rate * num_channels * bits_per_sample // 8
    block_align = num_channels * bits_per_sample // 8
    data_size = len(upsampled_int16) * 2
    
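    # Canonical 44-byte PCM WAV header: RIFF descriptor (12 bytes),
    # "fmt " sub-chunk (24 bytes, PCM layout), "data" sub-chunk header (8 bytes);
    # field order below matches the struct format string.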
    wav_header = struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',
        36 + data_size,
        b'WAVE',
        b'fmt ',
        16,  # fmt chunk size
        1,   # PCM format
        num_channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b'data',
        data_size,
    )
    
    return wav_header + upsampled_int16.tobytes()


async def generate_speech_chunked(
    text: str,
    speaker: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
    sample_rate: int = 16000,
    output_sample_rate: str = "16khz",
) -> Tuple[Optional[bytes], Dict[str, Any]]:
    """
    Generate speech with automatic text chunking for long inputs.
    
    Uses LongTextProcessor to split text and stitch audio.
    
    Args:
        text: Text to synthesize
        speaker: Internal speaker name
        ... (same as generate_speech)
        sample_rate: Internal generation sample rate (always 16000)
        output_sample_rate: "16khz" or "48khz" (triggers super-resolution)
    
    Returns:
        Tuple of (audio_bytes, metrics_dict)
    """
    if not is_initialized():
        raise RuntimeError("TTS runtime not initialized")
    
    runtime = get_runtime()
    metrics = {
        "ttfb_ms": 0,
        "generation_ms": 0,
        "chunks_processed": 0,
        "audio_duration_seconds": 0.0,
        "text_chunked": False,
        "sr_applied": False,
        "output_sample_rate": 16000,
    }
    
    start_time = time.time()
    
    try:
        # Import long text processor
        from veena3modal.processing.long_text_processor import LongTextProcessor
        
        long_processor = LongTextProcessor(pipeline=runtime.pipeline)
        
        # Check if chunking is needed
        if long_processor.should_chunk(text):
            metrics["text_chunked"] = True
            audio_bytes = await long_processor.generate_with_chunking(
                text=text,
                speaker=speaker,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
                sample_rate=sample_rate,
            )
        else:
            # Short text: direct generation
            audio_bytes = await runtime.pipeline.generate_speech_indic(
                speaker=speaker,
                text=text,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
            )
        
        generation_time = time.time() - start_time
        metrics["generation_ms"] = int(generation_time * 1000)
        metrics["ttfb_ms"] = metrics["generation_ms"]
        
        if audio_bytes:
            # Apply super-resolution if requested and available
            final_sample_rate = 16000
            if output_sample_rate == "48khz" and runtime.sr_service and runtime.sr_service.is_loaded:
                try:
                    sr_start = time.time()
                    logger.info(f"Applying super-resolution to chunked audio ({len(audio_bytes)} bytes)...")
                    audio_bytes = _apply_super_resolution(audio_bytes, runtime.sr_service)
                    sr_time = (time.time() - sr_start) * 1000
                    metrics["sr_ms"] = int(sr_time)
                    metrics["sr_applied"] = True
                    final_sample_rate = 48000
                    logger.info(f"✅ Super-resolution applied in {sr_time:.1f}ms, output={len(audio_bytes)} bytes")
                except Exception as e:
                    logger.warning(f"Super-resolution failed, returning 16kHz: {e}")
                    import traceback
                    traceback.print_exc()
            
            metrics["output_sample_rate"] = final_sample_rate
            audio_duration = (len(audio_bytes) - 44) / (final_sample_rate * 2)
            metrics["audio_duration_seconds"] = max(0.0, audio_duration)
        
        return audio_bytes, metrics
        
    except Exception as e:
        logger.error(f"Chunked speech generation failed: {e}")
        raise


async def generate_speech_streaming(
    text: str,
    speaker: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
    enable_chunking: bool = True,
) -> AsyncGenerator[Tuple[bytes, Dict[str, Any]], None]:
    """
    Generate speech audio with true streaming (yields chunks as they're generated).
    
    This is the core streaming implementation for M4.
    First yield includes WAV header + first PCM chunk.
    Subsequent yields are raw PCM chunks.
    
    Args:
        text: Text to synthesize (already normalized)
        speaker: Internal speaker name (resolved)
        temperature: Sampling temperature
        top_k: Top-k sampling
        top_p: Nucleus sampling
        max_tokens: Maximum tokens to generate
        repetition_penalty: Repetition penalty
        seed: Random seed for reproducibility
        enable_chunking: Enable text chunking for long inputs with voice consistency
    
    Yields:
        Tuple of (audio_bytes, metrics_dict)
        - First yield: WAV header (44 bytes) prepended to first PCM chunk
        - Subsequent yields: Raw PCM chunks (int16, 16kHz)
        - Final yield: metrics_dict has final timing info
    
    Raises:
        RuntimeError: If runtime not initialized or streaming pipeline unavailable
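    
    Example (illustrative consumption pattern):
        buf = bytearray()
        async for chunk, metrics in generate_speech_streaming(text, speaker):
            buf.extend(chunk)  # first chunk already carries the 44-byte WAV header
        # after the loop, metrics holds the final ttfb_ms / chunks_sent values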
    """
    if not is_initialized():
        raise RuntimeError("TTS runtime not initialized")
    
    runtime = get_runtime()
    
    if runtime.streaming_pipeline is None:
        raise RuntimeError("Streaming pipeline not available")
    
    # Import audio utils for WAV header
    from veena3modal.audio.utils import create_wav_header
    
    # Metrics tracking
    metrics = {
        "ttfb_ms": 0,
        "chunks_sent": 0,
        "total_bytes": 0,
        "audio_duration_seconds": 0.0,
        "text_chunked": False,
    }
    
    start_time = time.time()
    first_chunk_time = None
    total_pcm_bytes = 0
    sample_rate = 16000  # BiCodec sample rate
    
    # Check if we need text chunking
    from veena3modal.processing.long_text_processor import LongTextProcessor
    long_processor = LongTextProcessor(pipeline=runtime.pipeline)
    needs_chunking = enable_chunking and long_processor.should_chunk(text)
    
    if needs_chunking:
        # Chunked streaming with voice consistency (global token caching)
        metrics["text_chunked"] = True
        async for audio_chunk, chunk_metrics in _stream_chunked_text(
            runtime, long_processor, text, speaker, temperature, top_k, top_p,
            max_tokens, repetition_penalty, seed, sample_rate
        ):
            # First chunk: prepend WAV header
            if first_chunk_time is None:
                first_chunk_time = time.time()
                metrics["ttfb_ms"] = int((first_chunk_time - start_time) * 1000)
                # Create streaming WAV header (size=0 for unknown length)
                wav_header = create_wav_header(sample_rate=sample_rate, data_size=0)
                audio_chunk = wav_header + audio_chunk
            
            total_pcm_bytes += len(audio_chunk) - (44 if metrics["chunks_sent"] == 0 else 0)
            metrics["chunks_sent"] += 1
            metrics["total_bytes"] = total_pcm_bytes + 44  # Include header
            metrics["audio_duration_seconds"] = total_pcm_bytes / (sample_rate * 2)
            
            yield audio_chunk, metrics
    else:
        # Simple streaming (no text chunking)
        async for audio_chunk in runtime.streaming_pipeline.generate_speech_stream_indic(
            speaker=speaker,
            text=text,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            # First chunk: prepend WAV header
            if first_chunk_time is None:
                first_chunk_time = time.time()
                metrics["ttfb_ms"] = int((first_chunk_time - start_time) * 1000)
                wav_header = create_wav_header(sample_rate=sample_rate, data_size=0)
                audio_chunk = wav_header + audio_chunk
            
            total_pcm_bytes += len(audio_chunk) - (44 if metrics["chunks_sent"] == 0 else 0)
            metrics["chunks_sent"] += 1
            metrics["total_bytes"] = total_pcm_bytes + 44
            metrics["audio_duration_seconds"] = total_pcm_bytes / (sample_rate * 2)
            
            yield audio_chunk, metrics


async def _stream_chunked_text(
    runtime: TTSRuntime,
    long_processor,
    text: str,
    speaker: str,
    temperature: float,
    top_k: int,
    top_p: float,
    max_tokens: int,
    repetition_penalty: float,
    seed: Optional[int],
    sample_rate: int,
) -> AsyncGenerator[Tuple[bytes, Dict[str, Any]], None]:
    """
    Internal helper: stream chunked text with voice consistency.
    
    Uses global token caching from first chunk to maintain voice across chunks.
    """
    from veena3modal.audio.crossfade import crossfade_bytes_int16
    
    # Chunk text using the long text processor's chunker
    chunks = long_processor.chunk_text(text)
    
    if not chunks:
        return
    
    captured_globals: Optional[List[int]] = None
    previous_chunk_tail: Optional[bytes] = None
    chunk_metrics = {"chunks_processed": 0}
    
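    # crossfade_bytes_int16 blends the held-back tail of the previous emission
    # into the head of the new audio and returns (audio_to_emit, new_tail_to_hold);
    # the held tail is flushed after the loop below.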
    for i, chunk_text in enumerate(chunks):
        chunk_metrics["chunks_processed"] = i + 1
        
        if i == 0:
            # First chunk: capture global tokens for voice consistency
            async for audio_bytes, global_ids in runtime.streaming_pipeline.generate_speech_stream_indic_first_chunk(
                speaker=speaker,
                text=chunk_text,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
            ):
                if captured_globals is None and global_ids:
                    captured_globals = global_ids
                
                # Crossfade with previous tail (inter-chunk stitching)
                to_emit, previous_chunk_tail = crossfade_bytes_int16(
                    previous_chunk_tail,
                    audio_bytes,
                    sample_rate_hz=sample_rate,
                    crossfade_ms=50,
                )
                
                if to_emit:
                    yield to_emit, chunk_metrics
        else:
            # Continuation chunks: use captured globals for voice consistency
            if captured_globals is None:
                logger.warning(f"No captured globals for chunk {i+1}, using regular streaming")
                async for audio_bytes in runtime.streaming_pipeline.generate_speech_stream_indic(
                    speaker=speaker,
                    text=chunk_text,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_tokens=max_tokens,
                    repetition_penalty=repetition_penalty,
                    seed=seed,
                ):
                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        audio_bytes,
                        sample_rate_hz=sample_rate,
                        crossfade_ms=50,
                    )
                    if to_emit:
                        yield to_emit, chunk_metrics
            else:
                # Use continuation method with cached globals
                async for audio_bytes in runtime.streaming_pipeline.generate_speech_stream_indic_continuation(
                    speaker=speaker,
                    text=chunk_text,
                    global_ids=captured_globals,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_tokens=max_tokens,
                    repetition_penalty=repetition_penalty,
                    seed=seed,
                ):
                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        audio_bytes,
                        sample_rate_hz=sample_rate,
                        crossfade_ms=50,
                    )
                    if to_emit:
                        yield to_emit, chunk_metrics
    
    # Flush remaining tail
    if previous_chunk_tail:
        yield previous_chunk_tail, chunk_metrics
