"""
Text Chunking Utility for Long Inputs

Splits long texts into manageable chunks with smart boundary detection.
Supports English, Hindi, Telugu, and mixed-language content.

Why Text-Based Chunking?
- Preserves semantic meaning (sentences, paragraphs)
- Natural boundaries for better prosody
- Simpler than token-based (model handles tokenization internally)
- Works with SNAC token sliding window (separate layer)
"""

import re
from typing import List, Tuple


class TextChunker:
    """
    Smart text chunking for long inputs.

    Splits text at natural boundaries (paragraphs, sentences, commas, spaces)
    to avoid mid-word/mid-sentence splits. Within each window of
    ``max_chunk_length`` characters, the *last* boundary of the
    highest-priority kind present is used. Chunks therefore stay as long as
    the limit allows instead of fragmenting at the first boundary found.
    """

    # Sentence terminator followed by whitespace. Requiring whitespace is what
    # keeps decimals such as "3.14" from being split. Abbreviations such as
    # "Dr." are NOT recognised here and count as sentence ends.
    _SENTENCE_END_RE = re.compile(r'[.!?]\s+')
    # Comma followed by whitespace.
    _COMMA_RE = re.compile(r',\s+')

    def __init__(self, max_chunk_length: int = 1000):
        """
        Initialize text chunker.

        Args:
            max_chunk_length: Maximum characters per chunk (default: 1000).

        Raises:
            ValueError: If ``max_chunk_length`` is less than 1. A zero-width
                window can never make progress, so chunking would loop forever.
        """
        if max_chunk_length < 1:
            raise ValueError(
                f"max_chunk_length must be at least 1, got {max_chunk_length}"
            )
        self.max_chunk_length = max_chunk_length

    def chunk_text(self, text: str) -> List[str]:
        """
        Split text into chunks at natural boundaries.

        Priority:
        1. Paragraph breaks (\\n\\n)
        2. Sentence ends (. ! ?)
        3. Commas
        4. Spaces
        If a window contains none of these, it is hard-cut at
        ``max_chunk_length``.

        Args:
            text: Text to chunk

        Returns:
            List of text chunks. Text no longer than ``max_chunk_length`` is
            returned unchanged as the single element. Longer text yields
            whitespace-stripped chunks, with empty pieces dropped. At least
            one element is always returned.
        """
        if len(text) <= self.max_chunk_length:
            return [text]

        chunks: List[str] = []
        remaining = text

        while len(remaining) > self.max_chunk_length:
            # Find the best split point within max_chunk_length
            window = remaining[:self.max_chunk_length]
            split_point = self._find_split_point(window)

            if split_point <= 0:
                # No good split point found, force split at max length
                split_point = self.max_chunk_length

            piece = remaining[:split_point].strip()
            # A boundary at the very start of the window (e.g. input that
            # begins with "\n\n") yields only whitespace; never emit it.
            if piece:
                chunks.append(piece)
            remaining = remaining[split_point:].strip()

        # Add remaining text
        if remaining:
            chunks.append(remaining)

        # Whitespace-only input can strip down to nothing; keep the
        # "at least one chunk" contract of the short-input path.
        return chunks or [""]

    @staticmethod
    def _last_match_end(pattern: "re.Pattern", text: str) -> int:
        """Return the end offset of the last match of ``pattern`` in ``text``, or 0."""
        end = 0
        for match in pattern.finditer(text):
            end = match.end()
        return end

    def _find_split_point(self, text: str) -> int:
        """
        Find the best split point in ``text`` (one window of the input).

        Returns the index just past the chosen boundary, or 0 if none is
        found. For every boundary kind, the last occurrence in the window
        is used.

        A sentence terminator only counts when followed by whitespace, so
        decimals ("3.14") are never split. Abbreviations ("Dr.") and the
        final dot of an ellipsis ("... ") ARE treated as sentence ends here.
        """
        # 1. Paragraph break: split just after the last blank line.
        para = text.rfind('\n\n')
        if para != -1:
            return para + 2

        # 2. Sentence end (. ! ?) followed by whitespace.
        sentence_end = self._last_match_end(self._SENTENCE_END_RE, text)
        if sentence_end:
            return sentence_end

        # 3. Comma followed by whitespace.
        comma_end = self._last_match_end(self._COMMA_RE, text)
        if comma_end:
            return comma_end

        # 4. Space. Index 0 is ignored because it would yield an empty chunk.
        space = text.rfind(' ')
        if space > 0:
            return space + 1

        # 5. No good split point found
        return 0

    def chunk_with_overlap(
        self,
        text: str,
        overlap_chars: int = 50
    ) -> List[Tuple[str, int, int]]:
        """
        Chunk text with overlap for crossfade stitching.

        Each window after the first starts ``overlap_chars`` characters
        before the previous chunk's end. If that would not move past the
        previous chunk's start (short chunk or large overlap), the next
        window starts at the previous end with no overlap. This guarantees
        forward progress.

        Args:
            text: Text to chunk
            overlap_chars: Number of characters to overlap (default: 50)

        Returns:
            List of (chunk_text, start_idx, end_idx) tuples. ``start_idx`` and
            ``end_idx`` delimit the raw slice of ``text``. ``chunk_text`` is
            that slice with surrounding whitespace stripped.

        Raises:
            ValueError: If ``overlap_chars`` is negative.
        """
        if overlap_chars < 0:
            raise ValueError(f"overlap_chars must be >= 0, got {overlap_chars}")

        text_len = len(text)
        if text_len <= self.max_chunk_length:
            return [(text, 0, text_len)]

        chunks: List[Tuple[str, int, int]] = []
        start_idx = 0

        while start_idx < text_len:
            end_idx = min(start_idx + self.max_chunk_length, text_len)

            # Only look for a natural boundary if this is not the final window.
            if end_idx < text_len:
                split_point = self._find_split_point(text[start_idx:end_idx])
                if split_point > 0:
                    end_idx = start_idx + split_point

            chunks.append((text[start_idx:end_idx].strip(), start_idx, end_idx))

            if end_idx >= text_len:
                break

            # Step back by the overlap, but never to or before the current
            # start: that would repeat the same window forever.
            overlapped_start = end_idx - overlap_chars
            start_idx = overlapped_start if overlapped_start > start_idx else end_idx

        return chunks


class SentenceBoundaryChunker(TextChunker):
    """
    Enhanced chunker that prioritizes sentence boundaries.

    A sentence terminator (. ! ?) followed by whitespace is a split candidate
    unless the word it ends is a known abbreviation ("Dr.", "e.g.", ...).
    Decimals such as "3.14" are never split because the terminator must be
    followed by whitespace. When no valid sentence end exists, the chunker
    falls back to paragraph breaks, then commas, then spaces. It never falls
    back to an abbreviation period.
    """

    # Words whose trailing period does not end a sentence (case-sensitive).
    _ABBREVIATIONS = frozenset(
        {'Dr.', 'Mr.', 'Mrs.', 'Ms.', 'Prof.', 'Sr.', 'Jr.', 'e.g.', 'i.e.', 'etc.'}
    )
    # Opening punctuation that may be glued to an abbreviation, e.g. "(e.g.".
    _OPENING_PUNCT = '([{"\''

    def _find_split_point(self, text: str) -> int:
        """
        Find split point prioritizing sentence boundaries.

        Returns the index just past the last non-abbreviation sentence end in
        ``text``. If there is none, returns the result of the
        paragraph/comma/space fallback, or 0 if nothing is found.
        """
        last_valid_end = 0
        for match in re.finditer(r'[.!?]\s+', text):
            if not self._ends_with_abbreviation(text, match.start()):
                last_valid_end = match.end()

        if last_valid_end:
            return last_valid_end

        # Do NOT defer to TextChunker._find_split_point: its sentence tier is
        # unfiltered and would split right after an abbreviation such as "Dr.".
        return self._fallback_split_point(text)

    def _ends_with_abbreviation(self, text: str, pos: int) -> bool:
        """Return True if the whitespace-delimited word ending at ``text[pos]`` is an abbreviation."""
        start = pos
        while start > 0 and not text[start - 1].isspace():
            start -= 1
        word = text[start:pos + 1].lstrip(self._OPENING_PUNCT)
        return word in self._ABBREVIATIONS

    @staticmethod
    def _fallback_split_point(text: str) -> int:
        """Split after the last paragraph break, else last comma, else last space; 0 if none."""
        para = text.rfind('\n\n')
        if para != -1:
            return para + 2

        comma_end = 0
        for match in re.finditer(r',\s+', text):
            comma_end = match.end()
        if comma_end:
            return comma_end

        space = text.rfind(' ')
        return space + 1 if space > 0 else 0


class IndicSentenceChunker(TextChunker):
    """
    Enhanced chunker with Indic language support (Hindi, Telugu, multilingual).

    Handles:
    - Hindi: Devanagari danda (।) and double danda (॥)
    - Telugu: the same danda code points (U+0964/U+0965), which Unicode
      shares across Indic scripts, plus the Latin punctuation below
    - English: Standard punctuation (. ! ?)
    - Mixed language content

    Why this matters:
    - Indic languages use different punctuation (। instead of .)
    - Better prosody when splitting at natural language boundaries
    - Preserves meaning across languages
    """

    # Boundaries tried after paragraph breaks, highest priority first. For
    # each one, the LAST occurrence in the window wins so chunks stay long.
    _SPLIT_PATTERNS = (
        re.compile(r'[।॥]\s*'),   # Devanagari danda / double danda
        re.compile(r'[.!?]\s+'),  # Latin sentence end; needs whitespace, so "3.14" is safe
        re.compile(r',\s+'),      # Comma
    )

    def _find_split_point(self, text: str) -> int:
        """
        Find split point with Indic language awareness.

        Priority (last occurrence of each kind in ``text`` is used):
        1. Paragraph breaks ("\\n\\n")
        2. Indic sentence markers (। ॥)
        3. English sentence markers (. ! ?) followed by whitespace
        4. Commas
        5. Spaces

        Returns the index just past the chosen boundary, or 0 if none.
        """
        # 1. Paragraph break: split after the last blank line in the window.
        para = text.rfind('\n\n')
        if para != -1:
            return para + 2

        # 2-4. Indic sentence end, English sentence end, comma.
        for pattern in self._SPLIT_PATTERNS:
            last_end = 0
            for match in pattern.finditer(text):
                last_end = match.end()
            if last_end:
                return last_end

        # 5. Try space (last resort); index 0 would give an empty chunk.
        space = text.rfind(' ')
        return space + 1 if space > 0 else 0

    def detect_language_mix(self, text: str) -> dict:
        """
        Detect language composition in text.

        Counts Devanagari characters (reported as 'hindi'), Telugu characters
        and ASCII letters, and reports each as a percentage of their combined
        total. The danda and double danda (U+0964/U+0965) are excluded from
        the Devanagari count. They are punctuation shared by many Indic
        scripts (including Telugu), not evidence of Hindi.

        Returns:
            dict with 'primary' and the 'hindi'/'telugu'/'english'
            percentages (rounded to one decimal). 'primary' is the category
            with the strictly largest share. It is 'mixed' when the top share
            is tied, and 'unknown' when no counted characters are present.
        """
        # Devanagari block U+0900-U+097F, minus the shared dandas U+0964-U+0965.
        devanagari_chars = len(re.findall(r'[\u0900-\u0963\u0966-\u097F]', text))

        # Telugu range: U+0C00 - U+0C7F
        telugu_chars = len(re.findall(r'[\u0C00-\u0C7F]', text))

        # English (ASCII letters)
        english_chars = len(re.findall(r'[a-zA-Z]', text))

        total_chars = devanagari_chars + telugu_chars + english_chars

        if total_chars == 0:
            return {'primary': 'unknown', 'hindi': 0, 'telugu': 0, 'english': 0}

        hindi_pct = (devanagari_chars / total_chars) * 100
        telugu_pct = (telugu_chars / total_chars) * 100
        english_pct = (english_chars / total_chars) * 100

        # Determine primary language (strict maximum, otherwise 'mixed').
        if hindi_pct > max(telugu_pct, english_pct):
            primary = 'hindi'
        elif telugu_pct > max(hindi_pct, english_pct):
            primary = 'telugu'
        elif english_pct > max(hindi_pct, telugu_pct):
            primary = 'english'
        else:
            primary = 'mixed'

        return {
            'primary': primary,
            'hindi': round(hindi_pct, 1),
            'telugu': round(telugu_pct, 1),
            'english': round(english_pct, 1),
        }


class ParagraphChunker(TextChunker):
    """
    Chunker that packs whole paragraphs into chunks.

    Paragraphs (separated by two or more consecutive newlines) are stripped
    and greedily joined with "\\n\\n" until adding the next one would exceed
    ``max_chunk_length``. A paragraph that is too long on its own is broken
    up by the base ``TextChunker.chunk_text``. Useful for document-style text
    with clear paragraph structure.
    """

    def chunk_text(self, text: str) -> List[str]:
        """
        Split text on paragraph boundaries.

        Returns ``[text]`` unchanged when it already fits, or when it contains
        no non-blank paragraph. Otherwise returns the packed chunks.
        """
        limit = self.max_chunk_length
        if len(text) <= limit:
            return [text]

        result: List[str] = []
        buffer = ""

        for raw in re.split(r'\n\n+', text):
            paragraph = raw.strip()
            if not paragraph:
                continue

            candidate = f"{buffer}\n\n{paragraph}" if buffer else paragraph
            if len(candidate) <= limit:
                buffer = candidate
                continue

            # The new paragraph does not fit alongside the buffer: flush it.
            if buffer:
                result.append(buffer)

            if len(paragraph) <= limit:
                buffer = paragraph
            else:
                # Oversized paragraph: let the base chunker split it.
                result.extend(super().chunk_text(paragraph))
                buffer = ""

        if buffer:
            result.append(buffer)

        return result or [text]


def crossfade_audio(
    audio1: bytes,
    audio2: bytes,
    crossfade_samples: int = 1200,  # 50ms at 24kHz
    sample_rate: int = 24000
) -> bytes:
    """
    Crossfade two audio segments for seamless stitching.

    The last ``crossfade_samples`` samples of ``audio1`` are blended with the
    first ``crossfade_samples`` samples of ``audio2`` using complementary
    linear ramps. The result is therefore ``crossfade_samples`` samples
    shorter than plain concatenation. Blended samples are rounded to the
    nearest integer; a bare ``astype`` cast would truncate toward zero.

    Args:
        audio1: First audio segment (int16 PCM, native byte order)
        audio2: Second audio segment (int16 PCM, native byte order)
        crossfade_samples: Number of samples to crossfade (default: 1200 =
            50ms @ 24kHz). Values <= 0 disable the crossfade and the inputs
            are simply concatenated.
        sample_rate: Sample rate (default: 24000 Hz). Not used in the
            computation, since the crossfade length is given in samples.

    Returns:
        Stitched audio as int16 PCM bytes. If either segment is shorter than
        ``crossfade_samples``, the inputs are concatenated unchanged.

    Raises:
        ValueError: If either input's byte length is not a multiple of 2
            (raised by ``numpy.frombuffer``).

    Example:
        >>> audio1 = b'\\x00\\x01' * 24000  # 1 second
        >>> audio2 = b'\\x00\\x01' * 24000  # 1 second
        >>> result = crossfade_audio(audio1, audio2, crossfade_samples=1200)
        >>> len(result) < len(audio1) + len(audio2)  # Overlap reduces total length
        True
    """
    import numpy as np

    # Convert first so malformed (odd-length) buffers are always rejected.
    arr1 = np.frombuffer(audio1, dtype=np.int16)
    arr2 = np.frombuffer(audio2, dtype=np.int16)

    # Zero or negative length means "no crossfade". np.linspace would raise
    # on a negative sample count.
    if crossfade_samples <= 0:
        return audio1 + audio2

    # Check if we have enough samples for crossfade
    if len(arr1) < crossfade_samples or len(arr2) < crossfade_samples:
        return audio1 + audio2

    # fade_out: 1.0 -> 0.0 (first audio), fade_in: 0.0 -> 1.0 (second audio)
    fade_out = np.linspace(1.0, 0.0, crossfade_samples)
    fade_in = np.linspace(0.0, 1.0, crossfade_samples)

    tail = arr1[-crossfade_samples:]  # Last N samples of audio1
    head = arr2[:crossfade_samples]   # First N samples of audio2

    # Weighted sum in float64, rounded to nearest. The blend is a convex
    # combination of int16 values; clip only guards against float rounding
    # at the int16 limits.
    blended = np.clip(
        np.rint(tail * fade_out + head * fade_in), -32768, 32767
    ).astype(np.int16)

    # Stitch: audio1[:-N] + blended + audio2[N:]
    result = np.concatenate([
        arr1[:-crossfade_samples],
        blended,
        arr2[crossfade_samples:],
    ])

    return result.tobytes()

