"""
Spark TTS Generation Pipeline

End-to-end pipeline for TTS generation with BiCodec audio tokenizer.
Migrated from Veena3/Orpheus (SNAC) to Spark TTS (BiCodec).
"""

import asyncio
import logging
import re
import uuid
from typing import Any, Dict, List, Optional, Tuple

import torch
from vllm import SamplingParams

logger = logging.getLogger(__name__)

from veena3modal.core.constants import (
    TRAINING_STOP_TOKEN_IDS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_SEED,
    AUDIO_SAMPLE_RATE,
)


class SparkTTSPipeline:
    """
    End-to-end TTS pipeline for Spark TTS with BiCodec.
    
    Replaces SNAC-based token extraction with BiCodec regex-based extraction.
    """
    
    # Compiled once at class-creation time: these patterns run on every
    # generation request, so avoid per-call re.findall pattern lookups.
    _SEMANTIC_TOKEN_RE = re.compile(r"<\|bicodec_semantic_(\d+)\|>")
    _GLOBAL_TOKEN_RE = re.compile(r"<\|bicodec_global_(\d+)\|>")
    
    def __init__(
        self,
        model,
        prompt_builder,
        bicodec_decoder,
    ):
        """
        Initialize pipeline.
        
        Args:
            model: SparkTTS Model instance (must expose a vLLM `engine`)
            prompt_builder: IndicPromptBuilder instance (with Spark TTS format)
            bicodec_decoder: BiCodecDecoder instance
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.bicodec_decoder = bicodec_decoder
        
        # Use the module logger for consistency with the error paths below
        # (previously a bare print, which bypassed log handlers/formatting).
        logger.info("🚀 SparkTTSPipeline initialized")
    
    async def generate_speech(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = 1.05,
        seed: Optional[int] = None,
    ) -> Optional[bytes]:
        """
        Generate speech audio (non-streaming) using Spark TTS.
        
        NOTE: This method signature changed from description-based to speaker-based
        to align with Spark TTS architecture.
        
        Args:
            speaker: Speaker name (one of 12 predefined speakers)
            text: Text to synthesize (with optional [emotion] tags)
            temperature: Sampling temperature (default: 0.8 for Spark TTS)
            top_k: Top-k sampling (default: 50)
            top_p: Nucleus sampling (default: 1.0)
            max_tokens: Max BiCodec tokens to generate (default: 2048)
            repetition_penalty: Penalty > 1.0 discourages token repetition
                (guards against degenerate loops, e.g. token 484)
            seed: Random seed for reproducibility
        
        Returns:
            Audio bytes (int16 PCM WAV, 16kHz mono) or None if failed
        """
        # Build prompt using Spark TTS format
        prompt = self.prompt_builder.build_prefix(speaker, text)
        
        # Configure sampling for Spark TTS.
        # NOTE(review): vLLM's SamplingParams.stop expects stop *strings*;
        # integer token IDs belong in `stop_token_ids`. The constant's name
        # suggests IDs while the inline comment suggests the "<|im_end|>"
        # string — confirm TRAINING_STOP_TOKEN_IDS actually holds strings.
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,  # Prevent token repetition (e.g., token 484 loop)
            stop=TRAINING_STOP_TOKEN_IDS,  # "<|im_end|>"
            skip_special_tokens=False,  # Keep special tokens for BiCodec extraction
            seed=seed,
        )
        
        # Unique request id so concurrent requests don't collide in the engine
        request_id = f"req-{uuid.uuid4().hex[:12]}"
        
        # Use vLLM engine to generate
        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )
        
        # Drain the async generator; only the final cumulative output matters
        final_output = None
        async for request_output in results_generator:
            final_output = request_output
        
        if final_output is None:
            logger.error(
                "❌ Generation failed: no output",
                extra={"speaker": speaker, "text_length": len(text)},
            )
            return None
        
        # Extract generated text (contains BiCodec tokens)
        generated_text = final_output.outputs[0].text
        generated_ids = final_output.outputs[0].token_ids
        
        # Extract BiCodec tokens using regex
        semantic_ids, global_ids = self._extract_bicodec_tokens(generated_text)
        
        if not semantic_ids or not global_ids:
            logger.error(
                "❌ No BiCodec tokens found in output",
                extra={
                    "speaker": speaker,
                    "text_length": len(text),
                    "text_preview": text[:100],
                    "generated_tokens": len(generated_ids),
                    "semantic_tokens": len(semantic_ids),
                    "global_tokens": len(global_ids),
                    "generated_text_preview": generated_text[:500],
                }
            )
            return None
        
        # Validate tokens
        if not self.bicodec_decoder.validate_tokens(semantic_ids, global_ids):
            logger.error("❌ BiCodec token validation failed", extra={
                "speaker": speaker,
                "text_length": len(text),
                "semantic_tokens": len(semantic_ids),
                "global_tokens": len(global_ids),
            })
            return None
        
        # Decode to audio
        audio_bytes = self.bicodec_decoder.decode_to_bytes(semantic_ids, global_ids)
        
        if audio_bytes is None:
            logger.error("❌ BiCodec decode failed", extra={
                "speaker": speaker,
                "text_length": len(text),
            })
            return None
        
        # Add WAV header (deferred import: audio.utils is only needed on the
        # success path and keeps module import-time dependencies light)
        from veena3modal.audio.utils import add_wav_header
        wav_bytes = add_wav_header(audio_bytes, sample_rate=AUDIO_SAMPLE_RATE)
        
        return wav_bytes
    
    async def generate_speech_indic(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = 1.05,
        seed: Optional[int] = None,
    ) -> Optional[bytes]:
        """
        Generate speech audio for Spark TTS (speaker-based).
        
        This method is kept for backward compatibility with existing API.
        It delegates to generate_speech() with the same implementation.
        
        Args:
            speaker: Speaker name (one of 12 predefined speakers)
            text: Text to synthesize with inline emotion tags
                Examples:
                - "Hello! Welcome."
                - "[laughs] Hello there!"
                - "Hello <laugh> this is fun!" (will be normalized)
                - "नमस्ते! [excited] आज का दिन बहुत अच्छा है।"
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            max_tokens: Max BiCodec tokens to generate
            repetition_penalty: Penalty > 1.0 discourages token repetition
            seed: Random seed for reproducibility
        
        Returns:
            Audio bytes (int16 PCM WAV, 16kHz mono) or None if failed
        """
        # Delegate to generate_speech (same implementation for Spark TTS)
        return await self.generate_speech(
            speaker=speaker,
            text=text,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        )
    
    def _extract_bicodec_tokens(self, generated_text: str) -> Tuple[List[int], List[int]]:
        """
        Extract BiCodec semantic and global tokens from generated text using regex.
        
        Spark TTS generates tokens in the format:
        - <|bicodec_semantic_{id}|>
        - <|bicodec_global_{id}|>
        
        Args:
            generated_text: Generated text containing BiCodec token markers
        
        Returns:
            Tuple of (semantic_ids, global_ids)
            Returns ([], []) if no tokens found
        """
        # Precompiled class-level patterns avoid per-call pattern lookups
        semantic_ids = [int(t) for t in self._SEMANTIC_TOKEN_RE.findall(generated_text)]
        global_ids = [int(t) for t in self._GLOBAL_TOKEN_RE.findall(generated_text)]
        
        # Only log if there's an issue; lazy %-args avoid building the
        # message string unless the record is actually emitted
        if not semantic_ids and not global_ids:
            logger.error(
                "❌ No BiCodec tokens found! Generated text (first 1000 chars):\n%s",
                generated_text[:1000],
            )
        
        return semantic_ids, global_ids

