"""
Emotion Tag Normalizer for Spark TTS

Converts old Orpheus emotion tags to Spark TTS bracket format.
From: <emotion> → To: [emotion]

Also maps natural-language variants (e.g. [laughing] → [laughs]) to canonical
tags and provides a validator that reports any leftover angle-bracket tags.
"""

import logging
import re
from typing import Any, Dict, Iterable, Optional

from veena3modal.core.constants import LEGACY_EMOTION_MAP, INDIC_EMOTION_TAGS


# Spark TTS emotion lookup: bare tag name / natural-language phrase →
# canonical "[emotion]" tag. Legacy "<emotion>" keys are loaded first with
# their angle brackets stripped; the alias table below is applied on top, so
# an alias overrides a legacy name that happens to collide with it.
EMOTION_TO_TAG = {name.strip('<>'): tag for name, tag in LEGACY_EMOTION_MAP.items()}
EMOTION_TO_TAG.update(
    (alias, canonical)
    for canonical, aliases in (
        ('[laughs]', ('laughing', 'laugh')),
        ('[laughs harder]', ('laughs harder', 'laugh harder', 'laugh_harder')),
        ('[sighs]', ('sighs', 'sigh')),
        ('[giggle]', ('giggles', 'giggle')),
        ('[angry]', ('angry',)),
        ('[excited]', ('excited',)),
        ('[whispers]', ('whispers', 'whisper')),
        ('[screams]', ('screaming', 'scream', 'screams')),
        ('[sings]', ('singing', 'sing', 'sings')),
        ('[curious]', ('curious',)),
    )
    for alias in aliases
)


def normalize_emotion_tags(text: str) -> str:
    """
    Normalize emotion tags to Spark TTS bracket format.

    Two passes are applied, in order:

    1. Legacy angle-bracket tags are converted to bracket tags, e.g.
       ``<laugh>`` → ``[laughs]``, ``<sigh>`` → ``[sighs]``,
       ``<giggle>`` → ``[giggle]``. Matching is case-insensitive
       (``<Laugh>`` is treated like ``<laugh>``); tag names may contain
       letters and underscores only.
    2. Every ``[...]`` tag (including those produced by pass 1) is
       lowercased and stripped, then mapped to a canonical tag:

       - tags already in ``INDIC_EMOTION_TAGS`` are kept;
       - natural-language variants are looked up in ``EMOTION_TO_TAG``
         (``[laughing]`` → ``[laughs]``, ``[singing]`` → ``[sings]``);
       - as a fallback, a trailing ``s`` is dropped or added and the
         lookup retried.

    Unknown tags are not removed. They are logged at WARNING level and
    emitted in lowercase bracket form.

    Args:
        text: Text with potential emotion tags (old or new format).

    Returns:
        Text with normalized [emotion] tags for Spark TTS.
    """
    logger = logging.getLogger(__name__)

    # Pass 1: legacy <emotion> -> [emotion]. IGNORECASE makes the .lower()
    # below meaningful; without it, "<Laugh>" would never match at all.
    angle_pattern = re.compile(r'<([a-z_]+)>', re.IGNORECASE)

    def replace_angle_emotion(match) -> str:
        emotion_text = match.group(1).lower().strip()

        # Direct lookup in mapping
        if emotion_text in EMOTION_TO_TAG:
            return EMOTION_TO_TAG[emotion_text]

        # Normally redundant: EMOTION_TO_TAG is seeded from LEGACY_EMOTION_MAP
        # with '<>' stripped. Kept as a defensive fallback.
        old_tag = f'<{emotion_text}>'
        if old_tag in LEGACY_EMOTION_MAP:
            return LEGACY_EMOTION_MAP[old_tag]

        # Unknown emotion - convert to bracket format so pass 2 still sees it
        logger.warning(
            "Unknown old emotion: <%s> - converting to [%s]",
            emotion_text,
            emotion_text,
        )
        return f'[{emotion_text}]'

    text = angle_pattern.sub(replace_angle_emotion, text)

    # Pass 2: normalize any [emotion] tag, including the output of pass 1.
    bracket_pattern = re.compile(r'\[([^\]]+)\]')

    def replace_bracket_emotion(match) -> str:
        emotion_text = match.group(1).lower().strip()

        # If already a valid Spark TTS emotion, keep it
        bracket_tag = f'[{emotion_text}]'
        if bracket_tag in INDIC_EMOTION_TAGS:
            return bracket_tag

        # Direct lookup in mapping
        if emotion_text in EMOTION_TO_TAG:
            return EMOTION_TO_TAG[emotion_text]

        # Try without 's' (plural -> singular)
        if emotion_text.endswith('s'):
            singular = emotion_text[:-1]
            if singular in EMOTION_TO_TAG:
                return EMOTION_TO_TAG[singular]

        # Try with 's' (singular -> plural)
        plural_form = emotion_text + 's'
        if plural_form in EMOTION_TO_TAG:
            return EMOTION_TO_TAG[plural_form]

        # Unknown emotion - keep it (lowercased) in bracket format
        logger.warning("Unknown emotion: [%s] - keeping as-is", emotion_text)
        return bracket_tag

    return bracket_pattern.sub(replace_bracket_emotion, text)


def validate_normalized_text(
    text: str,
    *,
    valid_tags: Optional[Iterable[str]] = None,
) -> Dict[str, Any]:
    """
    Validate that text has proper emotion tag format for Spark TTS.

    Only the absence of angle-bracket spans decides ``is_valid``. Bracket
    tags outside the allowed set are reported in ``invalid_emotion_tags``
    but do not make the text invalid.

    Args:
        text: Text to check, typically the output of ``normalize_emotion_tags``.
        valid_tags: Bracket tags (e.g. ``"[laughs]"``) accepted as valid.
            Defaults to ``INDIC_EMOTION_TAGS``.

    Returns:
        Dict with keys:
          - ``has_angle_brackets``: True if any ``<...>`` span is present.
          - ``angle_brackets``: every ``<...>`` span found, in order.
          - ``bracket_emotions``: every ``[...]`` tag found, in order.
          - ``valid_emotion_tags``: bracket tags in the allowed set.
          - ``invalid_emotion_tags``: bracket tags not in the allowed set.
          - ``is_valid``: True when no angle-bracket spans remain.
    """
    # Only touch INDIC_EMOTION_TAGS when the caller did not supply a set.
    allowed = set(INDIC_EMOTION_TAGS if valid_tags is None else valid_tags)

    # Note: this matches any <...> span, which is broader than the
    # letters/underscores pattern the normalizer converts, so text like
    # "a < b > c" is also reported here.
    angle_tags = [f'<{tag}>' for tag in re.findall(r'<([^>]+)>', text)]
    bracket_tags = [f'[{tag}]' for tag in re.findall(r'\[([^\]]+)\]', text)]

    valid = [tag for tag in bracket_tags if tag in allowed]
    invalid = [tag for tag in bracket_tags if tag not in allowed]

    return {
        'has_angle_brackets': bool(angle_tags),  # Should be none
        'angle_brackets': angle_tags,
        'bracket_emotions': bracket_tags,
        'valid_emotion_tags': valid,
        'invalid_emotion_tags': invalid,
        'is_valid': not angle_tags,  # No angle brackets, only square brackets
    }


# Demo: normalize a handful of sample sentences and report the validation result.
if __name__ == "__main__":
    banner = "=" * 80
    samples = (
        "And of course, the so-called 'easy' hack didn't work at all [sighs].",
        "That's so silly [giggles].",
        "[angry] I cannot believe this happened for the third time.",
        "[whispers] Look over there.",
        "I'm so happy I could just [singing] la-la-la!",
        "And then he did it again [laughs harder]!",
        "Old format <laugh> text needs conversion.",
        "Old format <sigh> and <giggle> mixed.",
    )

    for header_line in (banner, "Emotion Tag Normalization Tests (Spark TTS)", banner):
        print(header_line)

    for sample in samples:
        print(f"\nOriginal:   {sample}")
        result = normalize_emotion_tags(sample)
        print(f"Normalized: {result}")

        report = validate_normalized_text(result)
        if not report['is_valid']:
            print(f"⚠️  Has angle brackets: {report['angle_brackets']}")
            continue
        print(f"✅ Valid (emotion tags: {report['valid_emotion_tags']})")

