"""
Spark TTS Indic Prompt Builder

Builds prompts for Spark TTS with Indic speaker system.
Format: <|task_controllable_tts|><|start_content|>...<|end_content|>...

Migrated from Orpheus to Spark TTS architecture.
"""

import logging
import re
from typing import List, Tuple

from veena3modal.core.constants import INDIC_SPEAKERS, INDIC_EMOTION_TAGS, SPEAKER_MAP
from veena3modal.processing.emotion_normalizer import normalize_emotion_tags


class IndicPromptBuilder:
    """
    Builds prompts in Spark TTS format with Indic speaker system.
    
    Format:
    <|task_controllable_tts|>
    <|start_content|>{text with [emotions]}<|end_content|>
    <|start_style_label|><|speaker_{id}|><|end_style_label|>
    <|start_global_token|>
    
    CRITICAL: This format MUST match Spark TTS training exactly.
    
    Model Details:
    - HuggingFace: BayAreaBoys/spark_tts_4speaker
    - Architecture: Qwen2ForCausalLM with BiCodec audio tokenizer
    - Languages: Telugu, Hindi, English, and more
    - Speakers: 12 predefined (case-sensitive!) - must match training map
    - Emotions: 10 tags in [bracket] format
    """

    # Number of BiCodec global tokens a prompt built by
    # build_prefix_with_globals() must carry (checked before anything else).
    NUM_GLOBAL_TOKENS = 32

    # Inline emotion tag such as "[laughs]": lowercase letters and whitespace
    # only. Compiled once and shared by validation and extraction so the two
    # can never disagree about what counts as a tag.
    _EMOTION_TAG_RE = re.compile(r'\[([a-z\s]+)\]')

    # Logger for non-fatal validation warnings (unknown emotion tags).
    _logger = logging.getLogger(__name__)

    def __init__(self, tokenizer, model=None):
        """
        Initialize Indic prompt builder for Spark TTS.
        
        Args:
            tokenizer: Transformers tokenizer (used by build_prefix_ids)
            model: Model instance (optional, kept for backward compatibility)
        """
        self.tokenizer = tokenizer
        self.model = model

    def _assemble_prompt(
        self,
        speaker: str,
        text: str,
        validate: bool,
        tail: Tuple[str, ...] = (),
    ) -> str:
        """
        Normalize, validate and wrap ``text`` in the Spark TTS prompt envelope.

        Shared by build_prefix() and build_prefix_with_globals() so the
        training-critical header is defined in exactly one place.

        Args:
            speaker: Speaker name (key of SPEAKER_MAP)
            text: Text to synthesize (emotion tags are normalized first)
            validate: Run _validate_inputs() on the normalized text
            tail: Extra strings appended after ``<|start_global_token|>``

        Returns:
            The assembled prompt string.

        Raises:
            ValueError: If validation fails or the speaker has no ID mapping.
        """
        # Normalize emotion tags: <emotion> → [emotion] (legacy compatibility)
        text = normalize_emotion_tags(text)

        if validate:
            self._validate_inputs(speaker, text)

        # The mapping is checked even when validate=False, because an
        # unmapped speaker cannot produce a valid speaker token.
        speaker_id = SPEAKER_MAP.get(speaker)
        if speaker_id is None:
            raise ValueError(f"Invalid speaker: {speaker}. Valid speakers: {list(SPEAKER_MAP.keys())}")

        # NOTE(review): `text` is inserted verbatim between control markers.
        # If it comes from end users it may contain literal markers such as
        # "<|end_content|>". Whether the tokenizer turns those into special
        # tokens depends on its configuration, so consider rejecting or
        # escaping them upstream. TODO confirm.
        return "".join([
            "<|task_controllable_tts|>",
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_style_label|>",
            f"<|speaker_{speaker_id}|>",
            "<|end_style_label|>",
            "<|start_global_token|>",
            *tail,
        ])

    def build_prefix(
        self,
        speaker: str,
        text: str,
        validate: bool = True
    ) -> str:
        """
        Build Spark TTS format prompt with Indic speaker.
        
        Args:
            speaker: Speaker name (one of 12 predefined speakers, case-sensitive!)
                Examples: 'lipakshi', 'reet', 'Nandini', 'Nilay', 'vardan', 'anika', 'adarsh',
                'krishna', 'Aarvi', 'Asha', 'Bittu', 'Mira'
            text: Text to synthesize with inline emotion tags
                Examples:
                - "Hello! Welcome to the demo."
                - "[laughs] The results were amazing!"
                - "The results were amazing and then [giggle] we celebrated!"
                - "नमस्ते! [excited] आज का दिन बहुत अच्छा है।" (Hindi)
                - "[curious] మీరు ఎలా ఉన్నారు?" (Telugu)
                - "Hello <laugh> this works too!" (will be normalized to [laughs])
            validate: Check format correctness (recommended for first use)
        
        Returns:
            Formatted prompt string ready for tokenization, ending at
            <|start_global_token|> so the model generates the global tokens.
            Format: <|task_controllable_tts|><|start_content|>text<|end_content|>...

        Raises:
            ValueError: If validation fails or the speaker has no ID mapping.
        
        Note: Emotion tags should be inline in text. The API will normalize:
        - <laugh> → [laughs] (old format)
        - laughing → [laughs] (natural language)
        """
        return self._assemble_prompt(speaker, text, validate)

    def build_prefix_ids(
        self,
        speaker: str,
        text: str,
        validate: bool = True
    ) -> List[int]:
        """
        Build prefix as token IDs (for testing/debugging).
        
        Args:
            speaker: Speaker name
            text: Text to synthesize (with inline emotions)
            validate: Check format correctness
        
        Returns:
            List of token IDs for the prefix, encoded without extra
            special tokens so the prompt markers are the only framing.
        """
        prompt = self.build_prefix(speaker, text, validate=validate)
        return self.tokenizer.encode(prompt, add_special_tokens=False)

    def _validate_inputs(self, speaker: str, text: str):
        """
        Validate speaker and text format.
        
        Args:
            speaker: Speaker name
            text: Text to synthesize (with inline emotions)
        
        Raises:
            ValueError: If inputs are invalid
        """
        # Validate speaker (case-sensitive!)
        if speaker not in INDIC_SPEAKERS:
            raise ValueError(
                f"Invalid speaker '{speaker}'. Must be one of (case-sensitive): "
                f"{', '.join(INDIC_SPEAKERS)}"
            )
        
        # Check text isn't empty
        if not text.strip():
            raise ValueError("Text cannot be empty")
        
        # Validate inline emotion tags in text (after normalization)
        self._validate_emotion_tags(text)

    def _validate_emotion_tags(self, text: str):
        """
        Validate emotion tags in text are in the allowed set for Spark TTS.
        
        Logs a warning (does not raise) when unknown tags are found. Such tags
        are not rejected, but they are not recognized emotion tags either.
        
        Args:
            text: Text with potential emotion tags
        """
        unknown_tags = [
            match.group(0)
            for match in self._EMOTION_TAG_RE.finditer(text)
            if match.group(0) not in INDIC_EMOTION_TAGS
        ]
        if not unknown_tags:
            return

        # Use logging rather than print(): callers and servers can route or
        # silence it, and one record per call keeps the logs readable.
        self._logger.warning(
            "Found %d unknown emotion tag(s) for Spark TTS "
            "(will be treated as regular text): %s. Valid Spark TTS emotion tags: %s",
            len(unknown_tags),
            ", ".join(unknown_tags),
            ", ".join(INDIC_EMOTION_TAGS),
        )

    def extract_emotion_tags(self, text: str) -> List[Tuple[str, int]]:
        """
        Extract emotion tags and their positions from text.
        
        Only tags present in INDIC_EMOTION_TAGS are returned; unknown
        bracketed tags are skipped.
        
        Args:
            text: Text with emotion tags
        
        Returns:
            List of (emotion_tag, start_index) tuples in order of appearance
        """
        return [
            (match.group(0), match.start())
            for match in self._EMOTION_TAG_RE.finditer(text)
            if match.group(0) in INDIC_EMOTION_TAGS
        ]

    def build_prefix_with_globals(
        self,
        speaker: str,
        text: str,
        global_ids: List[int],
        validate: bool = True
    ) -> str:
        """
        Build prompt with pre-generated global tokens for voice consistency in chunked generation.
        
        This is critical for multi-chunk text processing: the 32 global tokens encode
        speaker identity (voice DNA). By injecting the same global tokens from the first
        chunk into subsequent chunks, we ensure consistent voice across the entire text.
        
        Use Case:
        - Chunk 1: Use build_prefix() → model generates 32 global tokens + semantic tokens
        - Chunk 2+: Use build_prefix_with_globals() with captured global tokens
                    → model skips global generation, generates only semantic tokens
        
        Args:
            speaker: Speaker name (same as first chunk)
            text: Text chunk to synthesize
            global_ids: List of 32 global token IDs captured from first chunk
            validate: Check format correctness
        
        Returns:
            Formatted prompt with pre-filled global tokens, ending at
            <|start_semantic_token|>
            
        Raises:
            ValueError: If global_ids doesn't contain exactly 32 tokens, or if
                speaker/text validation fails
        
        Thread Safety:
            This method does not mutate instance state. Global tokens are passed
            explicitly per-request, not stored in any shared state.
        """
        # The count is checked first, before any text processing.
        if len(global_ids) != self.NUM_GLOBAL_TOKENS:
            raise ValueError(
                f"Expected exactly {self.NUM_GLOBAL_TOKENS} global tokens, got {len(global_ids)}. "
                f"This likely indicates an issue with first chunk generation."
            )

        # <|bicodec_global_0|><|bicodec_global_1|>...<|bicodec_global_31|>
        global_tokens_str = "".join(f"<|bicodec_global_{gid}|>" for gid in global_ids)

        # Global tokens are pre-filled and the prompt ends at the semantic marker,
        # so generation continues with semantic tokens only.
        return self._assemble_prompt(
            speaker,
            text,
            validate,
            (global_tokens_str, "<|start_semantic_token|>"),
        )

    @staticmethod
    def get_available_speakers() -> List[str]:
        """Get a fresh list of available speakers (safe to mutate)."""
        # list() rather than .copy(): works whether INDIC_SPEAKERS is a list or a tuple.
        return list(INDIC_SPEAKERS)

    @staticmethod
    def get_available_emotions() -> List[str]:
        """Get list of available emotions (without [ ])."""
        return [e.strip('[]') for e in INDIC_EMOTION_TAGS]

