"""
Long Text Processing Service

Handles text chunking and audio stitching for long text inputs.

Why This Approach:
- Text-based chunking: Preserves semantic meaning
- SNAC token sliding window: Separate layer for smooth audio
- Crossfade stitching: Seamless audio transitions between chunks
- Unlimited text: Can process any length text

Architecture:
1. Long text → IndicSentenceChunker → Text chunks
2. Each text chunk → TTS Model → Audio chunk  
3. All audio chunks → Crossfade stitching → Final audio
"""

import asyncio
import numpy as np
import struct
from typing import List, Optional, AsyncGenerator
import logging

from veena3modal.processing.text_chunker import IndicSentenceChunker, crossfade_audio
from veena3modal.core.streaming_pipeline import Veena3SlidingWindowPipeline
from veena3modal.core.pipeline import SparkTTSPipeline

logger = logging.getLogger(__name__)


def create_wav_header(data_size: int, sample_rate: int = 24000, channels: int = 1, bits_per_sample: int = 16) -> bytes:
    """
    Build a 44-byte canonical WAV (RIFF/PCM) header for raw PCM audio.

    Args:
        data_size: Length of the PCM payload in bytes
        sample_rate: Samples per second (Hz)
        channels: Channel count (1 = mono)
        bits_per_sample: Bit depth per sample

    Returns:
        The 44-byte header: RIFF chunk, 'fmt ' sub-chunk (PCM, format tag 1),
        and 'data' sub-chunk declaring `data_size` payload bytes.
    """
    # Derived fmt-chunk fields per the RIFF/WAVE spec.
    byte_rate = sample_rate * channels * bits_per_sample // 8      # bytes streamed per second
    block_align = channels * bits_per_sample // 8                  # bytes per sample frame

    pieces = (
        # RIFF descriptor: total file size minus the 8-byte RIFF preamble.
        struct.pack('<4sI4s', b'RIFF', 36 + data_size, b'WAVE'),
        # 'fmt ' sub-chunk: 16-byte PCM header (audio format 1 = uncompressed).
        struct.pack(
            '<4sIHHIIHH',
            b'fmt ', 16, 1, channels,
            sample_rate, byte_rate, block_align, bits_per_sample,
        ),
        # 'data' sub-chunk header announcing the PCM payload size.
        struct.pack('<4sI', b'data', data_size),
    )
    return b''.join(pieces)


class LongTextProcessor:
    """
    Process long texts by chunking, generating audio, and stitching.

    Features:
    - Intelligent text chunking (respects sentence boundaries, Indic languages)
    - Parallel audio generation (if desired)
    - Crossfade stitching for seamless audio
    - Memory efficient (processes chunks sequentially by default)

    Why text chunking, not token chunking:
    - Text maintains semantic meaning
    - Natural sentence boundaries = better prosody
    - Simpler implementation
    - Model handles tokenization internally
    - SNAC token sliding window is separate (already implemented)

    Max Input Length Reference (post-normalization):
    "इसके निर्माण में मुख्य वास्तुकार उस्ताद अहमद लाहौरी के नेतृत्व में लगभग 
    20,000 कारीगरों और शिल्पकारों ने दिन-रात मेहनत की थी। इमारत के सामने बना 
    'चारबाग' शैली का उद्यान और पानी की नहरें इसकी सुंदरता में चार चाँद लगा देते हैं।"
    = ~232 characters after normalization (numbers expanded to words)
    """

    # Maximum safe input length for model (characters)
    # Based on reference Hindi sentence (~232 chars after normalization)
    # Model struggles with longer inputs, causes token repetition
    MAX_MODEL_INPUT_LENGTH = 230  # ~232 chars is the reference max

    # Threshold for when to enable chunking (characters)
    # Trigger chunking when text approaches max safe length
    # Using 95% of max to leave some buffer
    CHUNKING_THRESHOLD = 220  # ~95% of MAX_MODEL_INPUT_LENGTH

    # Chunk size for long texts (characters)
    # Use 75% of max safe length to ensure each chunk processes cleanly
    # This gives model headroom for complex multilingual content
    CHUNK_SIZE = 170  # ~75% of MAX_MODEL_INPUT_LENGTH

    # Crossfade duration (milliseconds)
    CROSSFADE_MS = 50  # 50ms smooth transition

    def __init__(
        self,
        pipeline: SparkTTSPipeline,
        streaming_pipeline: Optional[Veena3SlidingWindowPipeline] = None
    ):
        """
        Initialize long text processor.

        Args:
            pipeline: Non-streaming TTS pipeline
            streaming_pipeline: Optional streaming pipeline (for streaming mode)
        """
        self.pipeline = pipeline
        self.streaming_pipeline = streaming_pipeline
        self.chunker = IndicSentenceChunker(max_chunk_length=self.CHUNK_SIZE)

        logger.info(
            "LongTextProcessor initialized",  # plain string: no interpolation needed
            extra={
                "max_model_input_length": self.MAX_MODEL_INPUT_LENGTH,
                "chunking_threshold": self.CHUNKING_THRESHOLD,
                "chunk_size": self.CHUNK_SIZE,
                "crossfade_ms": self.CROSSFADE_MS,
            }
        )

    def should_chunk(self, text: str) -> bool:
        """Return True if `text` exceeds the chunking threshold (characters)."""
        return len(text) > self.CHUNKING_THRESHOLD

    def chunk_text(self, text: str) -> List[str]:
        """
        Chunk text using the internal chunker.

        Delegates to self.chunker.chunk_text() for actual chunking logic.
        This method is used by _stream_chunked_text in tts_runtime.

        Args:
            text: Text to chunk

        Returns:
            List of text chunks
        """
        return self.chunker.chunk_text(text)

    async def generate_with_chunking(
        self,
        text: str,
        speaker: str,
        temperature: float = 0.4,
        top_k: int = 50,  # Added for Spark TTS
        top_p: float = 0.9,
        max_tokens: int = 4096,
        repetition_penalty: float = 1.05,  # Prevent token repetition
        seed: Optional[int] = None,
        sample_rate: int = 16000,  # BiCodec uses 16kHz, not 24kHz
    ) -> Optional[bytes]:
        """
        Generate audio for long text with chunking and stitching.

        Process:
        1. Chunk text at natural boundaries
        2. Generate audio for each chunk
        3. Stitch audio with crossfade

        Args:
            text: Long text to synthesize
            speaker: Speaker name (lipakshi, reet, etc.)
            temperature: Sampling temperature
            top_k: Top-k sampling (Spark TTS)
            top_p: Nucleus sampling
            max_tokens: Upper bound on tokens per chunk (per-chunk budget is
                derived from chunk length and capped at this value)
            repetition_penalty: Penalty to prevent token repetition
            seed: Random seed (same seed = same voice across chunks)
            sample_rate: Audio sample rate

        Returns:
            Stitched audio bytes (WAV: 44-byte header + int16 PCM),
            or None if generation fails for any chunk
        """
        # Detect language mix for logging
        lang_info = self.chunker.detect_language_mix(text)

        logger.info(
            "Processing long text with chunking",
            extra={
                "text_length": len(text),
                "primary_language": lang_info['primary'],
                "lang_mix": lang_info,
                "speaker": speaker,
            }
        )

        # 1. Chunk text at natural sentence boundaries
        text_chunks = self.chunker.chunk_text(text)

        # Single summary log (previously duplicated across two info calls)
        logger.info(
            f"🎵 Generating {len(text_chunks)} chunks for {len(text)} chars",
            extra={
                "num_chunks": len(text_chunks),
                "chunk_sizes": [len(chunk) for chunk in text_chunks],
                "total_chars": len(text),
                "speaker": speaker,
            }
        )

        # 2. Generate audio for each chunk sequentially
        audio_chunks = []

        for i, chunk in enumerate(text_chunks):
            # Calculate appropriate max_tokens for this chunk
            # BiCodec semantic tokens: ~50 tokens/second, ~5 words/second = ~10 tokens/word
            # Rough estimate: 5 chars/word, so ~2 tokens/char for audio generation
            # Add buffer for safety: multiply by 3 to handle variations
            chunk_max_tokens = min(
                int(len(chunk) * 3),  # ~2 tokens/char * 1.5 safety buffer = 3
                max_tokens  # FIX: respect caller's budget (was hard-coded 4096)
            )
            chunk_max_tokens = max(chunk_max_tokens, 100)  # Ensure minimum

            # Generate audio for this chunk (returns WAV format with header)
            audio_bytes = await self.pipeline.generate_speech_indic(
                speaker=speaker,
                text=chunk,
                temperature=temperature,
                top_k=top_k,  # Added for Spark TTS
                top_p=top_p,
                max_tokens=chunk_max_tokens,  # Use calculated per-chunk limit
                repetition_penalty=repetition_penalty,  # Prevent token repetition
                seed=seed,  # Use same seed for consistent voice across chunks
            )

            if audio_bytes is None:
                logger.error(
                    f"❌ Failed to generate audio for chunk {i+1}/{len(text_chunks)}",
                    extra={
                        "chunk_index": i,
                        "chunk_length": len(chunk),
                        "chunk_text_preview": chunk[:150],
                        "total_chunks": len(text_chunks),
                    }
                )
                return None

            # Strip WAV header (44 bytes) to get raw PCM data for stitching
            # NOTE(review): assumes a standard 44-byte header from the pipeline —
            # confirm against generate_speech_indic's output format
            pcm_data = audio_bytes[44:] if len(audio_bytes) > 44 else audio_bytes
            audio_chunks.append(pcm_data)

        # 3. Stitch audio with crossfade
        stitched_pcm = self.stitch_audio_chunks(
            audio_chunks,
            sample_rate=sample_rate,
            crossfade_ms=self.CROSSFADE_MS,
        )

        # 4. Add WAV header to stitched PCM data
        wav_header = create_wav_header(
            data_size=len(stitched_pcm),
            sample_rate=sample_rate,
            channels=1,
            bits_per_sample=16
        )
        final_audio = wav_header + stitched_pcm

        # Duration in seconds: 2 bytes per int16 sample, mono
        total_audio_duration = len(stitched_pcm) / (sample_rate * 2)

        logger.info(
            f"✅ Generated {len(text_chunks)} chunks → {total_audio_duration:.1f}s audio",
            extra={
                "total_chunks": len(text_chunks),
                "audio_duration": total_audio_duration,
                "audio_size_mb": len(final_audio) / (1024 * 1024),
            }
        )

        return final_audio

    def stitch_audio_chunks(
        self,
        audio_chunks: List[bytes],
        sample_rate: int = 16000,  # BiCodec uses 16kHz, not 24kHz
        crossfade_ms: int = 50,
    ) -> bytes:
        """
        Stitch multiple audio chunks with crossfade transitions.

        Why crossfade:
        - Prevents pops/clicks at chunk boundaries
        - Smooth transitions between chunks
        - Professional audio quality

        Args:
            audio_chunks: List of audio bytes (int16 PCM)
            sample_rate: Sample rate (Hz)
            crossfade_ms: Crossfade duration (milliseconds)

        Returns:
            Stitched audio bytes (empty bytes for an empty list; the chunk
            itself for a single-element list)
        """
        if not audio_chunks:
            return b''

        if len(audio_chunks) == 1:
            return audio_chunks[0]

        # Convert crossfade duration to a sample count at this rate
        crossfade_samples = int((crossfade_ms / 1000.0) * sample_rate)

        # Fold chunks left-to-right, crossfading each boundary
        result = audio_chunks[0]

        for i in range(1, len(audio_chunks)):
            result = crossfade_audio(
                result,
                audio_chunks[i],
                crossfade_samples=crossfade_samples,
                sample_rate=sample_rate,
            )

            logger.debug(
                f"Crossfaded chunk {i+1}",
                extra={
                    "chunk_index": i,
                    "current_length_bytes": len(result),
                }
            )

        return result

    async def generate_with_chunking_streaming(
        self,
        text: str,
        speaker: str,
        temperature: float = 0.4,
        top_p: float = 0.9,
        max_tokens: int = 4096,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate audio for long text with streaming (experimental).

        Note: Streaming long texts is more complex:
        - Need to stream each chunk
        - Need to buffer for crossfade
        - Adds latency for stitching

        For now, recommend using non-streaming for long texts.

        Args:
            text: Long text to synthesize
            speaker: Speaker name
            temperature: Sampling temperature
            top_p: Nucleus sampling
            max_tokens: Max tokens per chunk
            seed: Random seed

        Yields:
            Audio bytes (in chunks)

        Raises:
            ValueError: If no streaming pipeline was provided at init
        """
        if self.streaming_pipeline is None:
            raise ValueError("Streaming pipeline not available")

        # Chunk text
        text_chunks = self.chunker.chunk_text(text)

        logger.info(
            f"Streaming long text with {len(text_chunks)} chunks",
            extra={"num_chunks": len(text_chunks)}
        )

        # For streaming, we process each chunk sequentially
        # Future optimization: overlap processing + streaming
        for i, chunk in enumerate(text_chunks):
            logger.info(f"Streaming chunk {i+1}/{len(text_chunks)}")

            # Stream this chunk, forwarding audio as it arrives
            async for audio_bytes in self.streaming_pipeline.generate_speech_stream_indic(
                speaker=speaker,
                text=chunk,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                seed=seed,
            ):
                yield audio_bytes

            # TODO: Add crossfade between chunks in streaming mode
            # This requires buffering the end of previous chunk
            # and beginning of next chunk

