"""
Audio Format Encoders for Veena3 TTS

Provides encoding utilities for different audio formats:
- WAV (Raw PCM with RIFF header)
- Opus (compressed, low-latency codec; encoded here as a complete file)
- MP3 (universal compatibility)
- mu-law (telephony, 8kHz, in a WAV container)
- FLAC (lossless compression)
"""

import io
import os
import struct
import subprocess
import tempfile
from abc import ABC, abstractmethod
from typing import Optional


def create_wav_header(
    sample_rate: int = 24000,
    channels: int = 1,
    bits_per_sample: int = 16,
    data_size: int = 0
) -> bytes:
    """
    Create a 44-byte canonical PCM WAV header (RIFF/WAVE with ``fmt `` and ``data`` chunks).
    
    Args:
        sample_rate: Audio sample rate in Hz (default: 24000)
        channels: Number of audio channels (default: 1 = mono)
        bits_per_sample: Bits per sample (default: 16)
        data_size: Size of audio data in bytes (0 = unknown/streaming)
    
    Returns:
        44-byte RIFF/WAVE header (AudioFormat = 1, i.e. integer PCM)
    
    Note:
        Each sample is stored in a whole number of bytes: ``bits_per_sample``
        is rounded up to the next multiple of 8 when computing BlockAlign and
        ByteRate (e.g. 12-bit samples occupy 2-byte containers). For 8/16/24/32-bit
        audio this is identical to ``bits_per_sample // 8``.

        For streaming with unknown size, set data_size=0; the header then
        advertises an empty data chunk and the caller appends PCM until EOF.
        NOTE(review): parsers that honour the size fields literally may treat
        such a stream as empty — confirm the target clients tolerate it.
        
    Example:
        >>> header = create_wav_header(sample_rate=24000, data_size=48000)
        >>> len(header)
        44
        >>> header[:4]
        b'RIFF'
    """
    # Whole bytes per sample, rounded up so non-byte-aligned bit depths still
    # get a full container (floor division would under-report them).
    bytes_per_sample = (bits_per_sample + 7) // 8
    block_align = channels * bytes_per_sample
    byte_rate = sample_rate * block_align
    
    # Pack the header using struct
    # Format: Little-endian (<) with specific field sizes
    header = struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',          # ChunkID (4 bytes)
        36 + data_size,   # ChunkSize (4 bytes) - file size - 8
        b'WAVE',          # Format (4 bytes)
        b'fmt ',          # Subchunk1ID (4 bytes)
        16,               # Subchunk1Size (4 bytes) - PCM format
        1,                # AudioFormat (2 bytes) - 1 = PCM
        channels,         # NumChannels (2 bytes)
        sample_rate,      # SampleRate (4 bytes)
        byte_rate,        # ByteRate (4 bytes)
        block_align,      # BlockAlign (2 bytes)
        bits_per_sample,  # BitsPerSample (2 bytes)
        b'data',          # Subchunk2ID (4 bytes)
        data_size         # Subchunk2Size (4 bytes)
    )
    
    return header


class AudioEncoder(ABC):
    """
    Abstract contract shared by every audio format encoder.

    Concrete subclasses must provide all three methods below; the class
    itself cannot be instantiated.
    """

    @abstractmethod
    def encode(self, pcm_data: bytes, sample_rate: int) -> bytes:
        """
        Convert raw PCM samples into the encoder's output format.

        Args:
            pcm_data: Raw 16-bit PCM audio bytes.
            sample_rate: Sample rate (Hz) of ``pcm_data``.

        Returns:
            The encoded audio as bytes.
        """
        ...

    @abstractmethod
    def get_content_type(self) -> str:
        """Return the MIME type to send in an HTTP Content-Type header."""
        ...

    @abstractmethod
    def supports_streaming(self) -> bool:
        """Return whether output can be produced chunk by chunk."""
        ...


class WAVEncoder(AudioEncoder):
    """
    Encoder that wraps raw PCM in a RIFF/WAVE container.

    No transcoding happens: the PCM bytes pass through untouched and only a
    44-byte header (built by ``create_wav_header``) is prepended. Because of
    that, the format is reported as streaming-capable.
    """

    def __init__(self, sample_rate: int = 24000, channels: int = 1, bits_per_sample: int = 16):
        """
        Remember the audio parameters written into every header.

        Args:
            sample_rate: Sample rate (Hz) advertised in the header
            channels: Channel count (1 = mono, 2 = stereo)
            bits_per_sample: Sample width in bits (typically 16)
        """
        self.sample_rate, self.channels, self.bits_per_sample = (
            sample_rate,
            channels,
            bits_per_sample,
        )

    def _header(self, data_size: int) -> bytes:
        """Build a header for this encoder's parameters with the given data size."""
        return create_wav_header(
            sample_rate=self.sample_rate,
            channels=self.channels,
            bits_per_sample=self.bits_per_sample,
            data_size=data_size,
        )

    def encode(self, pcm_data: bytes, sample_rate: int) -> bytes:
        """
        Produce a complete WAV file from PCM bytes.

        Args:
            pcm_data: Raw 16-bit PCM audio bytes
            sample_rate: Not consulted; the header always uses the rate
                given to the constructor

        Returns:
            Header immediately followed by the unchanged PCM bytes
        """
        # NOTE(review): if callers pass PCM at a rate different from
        # self.sample_rate, the header will mislabel it — confirm callers
        # always construct the encoder with the PCM's real rate.
        return self._header(len(pcm_data)) + pcm_data

    def get_content_type(self) -> str:
        """MIME type for WAV output."""
        return "audio/wav"

    def supports_streaming(self) -> bool:
        """WAV output can be streamed (header first, then raw PCM)."""
        return True

    def create_streaming_header(self) -> bytes:
        """
        Build a header to send before streamed PCM whose length is unknown.

        Returns:
            WAV header whose data size field is 0
        """
        return self._header(0)


class OpusEncoder(AudioEncoder):
    """
    Encoder that turns raw PCM into an Opus file by running ffmpeg.

    Opus is designed for low-latency internet audio. The whole PCM buffer
    is encoded in one ffmpeg subprocess (libopus codec, ``opus`` muxer).
    """

    def __init__(self, sample_rate: int = 24000, bitrate: str = "48k"):
        """
        Configure the encoder.

        Args:
            sample_rate: Nominal input sample rate (stored on the instance;
                ``encode`` takes the actual rate from its own argument)
            bitrate: ffmpeg bitrate string for the output, e.g. "48k", "64k"
        """
        self.sample_rate = sample_rate
        self.bitrate = bitrate

    def encode(self, pcm_data: bytes, sample_rate: int) -> bytes:
        """
        Encode PCM bytes to Opus through an ffmpeg pipe.

        Args:
            pcm_data: Raw 16-bit PCM audio bytes (mono)
            sample_rate: Sample rate (Hz) of ``pcm_data``

        Returns:
            The Opus file bytes ffmpeg wrote to stdout

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits with a non-zero status
        """
        # Input side: headerless signed 16-bit little-endian mono PCM on stdin.
        input_opts = ['-f', 's16le', '-ar', str(sample_rate), '-ac', '1', '-i', 'pipe:0']
        # Output side: libopus at the configured bitrate, Opus container on stdout.
        output_opts = ['-c:a', 'libopus', '-b:a', self.bitrate, '-f', 'opus', 'pipe:1']

        completed = subprocess.run(
            ['ffmpeg', *input_opts, *output_opts],
            input=pcm_data,
            capture_output=True,
            check=True,
        )
        return completed.stdout

    def get_content_type(self) -> str:
        """MIME type for Opus output."""
        return "audio/opus"

    def supports_streaming(self) -> bool:
        """Not streamed: this encoder needs the complete PCM buffer."""
        return False


class MP3Encoder(AudioEncoder):
    """
    MP3 format encoder using ffmpeg.
    
    MP3 offers universal compatibility across devices.
    Uses ffmpeg with the libmp3lame encoder in VBR mode (``-q:a``).
    """
    
    def __init__(self, sample_rate: int = 24000, quality: int = 4):
        """
        Initialize MP3 encoder.
        
        Args:
            sample_rate: Nominal input sample rate (stored on the instance;
                ``encode`` takes the actual rate from its own argument)
            quality: VBR quality (0=best, 9=worst), default 4=good quality
        """
        self.sample_rate = sample_rate
        self.quality = quality
    
    def encode(self, pcm_data: bytes, sample_rate: int) -> bytes:
        """
        Encode PCM data as MP3 file using ffmpeg.

        ffmpeg only writes the Xing/LAME info frame (total frames and bytes,
        seek table) when it can seek back in its output. For VBR output that
        frame is what lets players show the right duration and seek
        accurately. The output is therefore written to a temporary file
        rather than a pipe, then read back.
        
        Args:
            pcm_data: Raw 16-bit PCM audio data (mono)
            sample_rate: Sample rate of input
            
        Returns:
            Encoded MP3 file
            
        Raises:
            subprocess.CalledProcessError: If ffmpeg fails (its stderr is
                available on the exception's ``stderr`` attribute)
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            out_path = os.path.join(tmp_dir, 'output.mp3')
            subprocess.run(
                [
                    'ffmpeg',
                    '-y',  # never prompt about overwriting; stdin carries the PCM
                    '-f', 's16le',  # Input format
                    '-ar', str(sample_rate),  # Input sample rate
                    '-ac', '1',  # Mono
                    '-i', 'pipe:0',  # stdin
                    '-c:a', 'libmp3lame',  # MP3 encoder
                    '-q:a', str(self.quality),  # VBR quality
                    '-f', 'mp3',  # Output format
                    out_path,  # seekable output so the Xing/LAME frame is written
                ],
                input=pcm_data,
                capture_output=True,
                check=True
            )
            # Read before the temporary directory is cleaned up.
            with open(out_path, 'rb') as encoded:
                return encoded.read()
    
    def get_content_type(self) -> str:
        """Get MIME type for MP3."""
        return "audio/mpeg"
    
    def supports_streaming(self) -> bool:
        """MP3 does not support chunk-by-chunk streaming."""
        return False


class MuLawEncoder(AudioEncoder):
    """
    mu-law format encoder for telephony applications.
    
    Resamples to 8 kHz and encodes with the ``pcm_mulaw`` codec inside a
    WAV container. Uses ffmpeg for resampling and mu-law encoding.
    """
    
    def __init__(self):
        """Initialize mu-law encoder (output is always 8 kHz)."""
        self.target_sample_rate = 8000
    
    def encode(self, pcm_data: bytes, sample_rate: int) -> bytes:
        """
        Encode PCM data as a mu-law WAV file.

        ffmpeg's WAV muxer fills in the RIFF and data chunk sizes only at the
        end of encoding, by seeking back in the output. When writing to a
        pipe it cannot do that, so the header keeps placeholder sizes. The
        output is therefore written to a temporary file and read back, which
        yields a header with the real sizes.
        
        Args:
            pcm_data: Raw 16-bit PCM audio data (mono)
            sample_rate: Sample rate of input
            
        Returns:
            Encoded mu-law WAV file (8kHz)
            
        Raises:
            subprocess.CalledProcessError: If ffmpeg fails (its stderr is
                available on the exception's ``stderr`` attribute)
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            out_path = os.path.join(tmp_dir, 'output.wav')
            subprocess.run(
                [
                    'ffmpeg',
                    '-y',  # never prompt about overwriting; stdin carries the PCM
                    '-f', 's16le',  # Input format
                    '-ar', str(sample_rate),  # Input sample rate
                    '-ac', '1',  # Mono
                    '-i', 'pipe:0',  # stdin
                    '-ar', str(self.target_sample_rate),  # Resample to 8kHz
                    '-acodec', 'pcm_mulaw',  # mu-law codec
                    '-f', 'wav',  # WAV container with mu-law
                    out_path,  # seekable output so chunk sizes get finalized
                ],
                input=pcm_data,
                capture_output=True,
                check=True
            )
            # Read before the temporary directory is cleaned up.
            with open(out_path, 'rb') as encoded:
                return encoded.read()
    
    def get_content_type(self) -> str:
        """Get MIME type for mu-law WAV."""
        return "audio/x-wav"
    
    def supports_streaming(self) -> bool:
        """mu-law WAV could support streaming but needs resampling first."""
        return False


class FLACEncoder(AudioEncoder):
    """
    FLAC format encoder for lossless compression.
    
    Provides lossless compression (typically 50-60% of original size).
    Uses ffmpeg for encoding.
    """
    
    def __init__(self, sample_rate: int = 24000, compression_level: int = 5):
        """
        Initialize FLAC encoder.
        
        Args:
            sample_rate: Nominal input sample rate (stored on the instance;
                ``encode`` takes the actual rate from its own argument)
            compression_level: Compression level (0=fast, 8=best), default 5
        """
        self.sample_rate = sample_rate
        self.compression_level = compression_level
    
    def encode(self, pcm_data: bytes, sample_rate: int) -> bytes:
        """
        Encode PCM data as FLAC file using ffmpeg.

        ffmpeg's FLAC muxer rewrites the STREAMINFO block (total sample
        count, MD5 of the audio) after encoding only when its output is
        seekable. Written to a pipe, those fields stay "unknown". The output
        is therefore written to a temporary file and read back.
        
        Args:
            pcm_data: Raw 16-bit PCM audio data (mono)
            sample_rate: Sample rate of input
            
        Returns:
            Encoded FLAC file
            
        Raises:
            subprocess.CalledProcessError: If ffmpeg fails (its stderr is
                available on the exception's ``stderr`` attribute)
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            out_path = os.path.join(tmp_dir, 'output.flac')
            subprocess.run(
                [
                    'ffmpeg',
                    '-y',  # never prompt about overwriting; stdin carries the PCM
                    '-f', 's16le',  # Input format
                    '-ar', str(sample_rate),  # Input sample rate
                    '-ac', '1',  # Mono
                    '-i', 'pipe:0',  # stdin
                    '-c:a', 'flac',  # FLAC encoder
                    '-compression_level', str(self.compression_level),  # Compression
                    '-f', 'flac',  # Output format
                    out_path,  # seekable output so STREAMINFO is finalized
                ],
                input=pcm_data,
                capture_output=True,
                check=True
            )
            # Read before the temporary directory is cleaned up.
            with open(out_path, 'rb') as encoded:
                return encoded.read()
    
    def get_content_type(self) -> str:
        """Get MIME type for FLAC."""
        return "audio/flac"
    
    def supports_streaming(self) -> bool:
        """FLAC does not support chunk-by-chunk streaming."""
        return False

