"""
Audio processing module for handling segment duration and chunking.
Provides controls for splitting long segments to meet the 10s hard limit.
"""
import os
import math
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass
from pydub import AudioSegment


@dataclass
class AudioChunk:
    """One piece of audio that can be handed to transcription.

    Attributes:
        file_path: Location of the audio file for this chunk.
        original_segment: Filename of the segment the chunk came from.
        chunk_index: Position of the chunk within its segment (0 if the
            segment was not split).
        total_chunks: Number of chunks produced for the segment.
        start_ms: Start offset within the original segment, in milliseconds.
        end_ms: End offset within the original segment, in milliseconds.
        duration_sec: Length of the chunk in seconds.
    """

    file_path: str
    original_segment: str
    chunk_index: int
    total_chunks: int
    start_ms: int
    end_ms: int
    duration_sec: float

    @property
    def is_split(self) -> bool:
        """Whether the originating segment produced more than one chunk."""
        has_siblings = self.total_chunks > 1
        return has_siblings


class AudioProcessor:
    """Handles audio segment processing with duration controls.

    Segments longer than ``max_duration_sec`` are cut into chunks that are
    exported as new files; shorter segments are passed through untouched.
    """

    def __init__(
        self,
        max_duration_sec: float = 10.0,
        min_duration_sec: float = 1.0,
        overlap_sec: float = 0.0,
        output_format: str = "flac"
    ):
        """
        Initialize the audio processor.

        Args:
            max_duration_sec: Maximum allowed segment duration (hard limit)
            min_duration_sec: Minimum segment duration to process
            overlap_sec: Overlap between chunks when splitting (helps with word boundaries)
            output_format: Output format for split chunks (flac, wav, mp3)
        """
        self.max_duration_sec = max_duration_sec
        self.min_duration_sec = min_duration_sec
        self.overlap_sec = overlap_sec
        self.output_format = output_format

    def get_audio_duration(self, file_path: str) -> float:
        """Get duration of an audio file in seconds.

        The file is loaded with ``AudioSegment.from_file`` and its length in
        milliseconds is converted to seconds.
        """
        audio = AudioSegment.from_file(file_path)
        return len(audio) / 1000.0

    def needs_splitting(self, file_path: str) -> bool:
        """Check if a file is longer than ``max_duration_sec``."""
        return self.get_audio_duration(file_path) > self.max_duration_sec

    @staticmethod
    def _plan_chunk_bounds(
        total_ms: int,
        chunk_ms: int,
        overlap_ms: int,
        min_ms: float
    ) -> List[tuple]:
        """
        Compute the ``(start_ms, end_ms)`` window of every chunk.

        Windows are ``chunk_ms`` long and each one starts
        ``chunk_ms - overlap_ms`` after the previous one; the final window is
        clipped to ``total_ms``. If that final window would be shorter than
        ``min_ms``, the last two windows are re-cut into two equally sized
        ones so that no audio is dropped and no window exceeds ``chunk_ms``.

        Args:
            total_ms: Length of the audio in milliseconds
            chunk_ms: Target window length in milliseconds
            overlap_ms: Overlap between consecutive windows in milliseconds
            min_ms: Minimum acceptable length of the trailing window

        Returns:
            List of ``(start_ms, end_ms)`` tuples in playback order

        Raises:
            ValueError: If ``chunk_ms`` is not positive or ``overlap_ms`` is
                negative or not smaller than ``chunk_ms``
        """
        if chunk_ms <= 0 or not 0 <= overlap_ms < chunk_ms:
            raise ValueError(
                "chunk duration must be positive and overlap must satisfy "
                f"0 <= overlap < chunk duration (chunk={chunk_ms}ms, "
                f"overlap={overlap_ms}ms)"
            )
        if total_ms <= 0:
            return []

        step_ms = chunk_ms - overlap_ms
        # At least one window, even when the overlap is longer than the audio.
        num_chunks = max(1, math.ceil((total_ms - overlap_ms) / step_ms))

        bounds = [
            (i * step_ms, min(i * step_ms + chunk_ms, total_ms))
            for i in range(num_chunks)
        ]

        # Only the final window can be clipped, so it is the only one that
        # can end up too short.
        if len(bounds) >= 2:
            tail_start, tail_end = bounds[-1]
            if tail_end - tail_start < min_ms:
                prev_start = bounds[-2][0]
                span_ms = total_ms - prev_start
                # Cut the combined span in half while keeping the overlap.
                new_tail_start = prev_start + (span_ms - overlap_ms) // 2
                bounds[-2] = (prev_start, new_tail_start + overlap_ms)
                bounds[-1] = (new_tail_start, total_ms)

        return bounds

    def split_audio(
        self,
        file_path: str,
        output_dir: str,
        chunk_duration_sec: Optional[float] = None
    ) -> List["AudioChunk"]:
        """
        Split an audio file into chunks of specified duration.

        A file that is not longer than ``max_duration_sec`` is returned as a
        single chunk pointing at the original file; nothing is exported in
        that case. Otherwise each chunk is exported to ``output_dir`` as
        ``<stem>_chunkNNN.<output_format>``.

        Args:
            file_path: Path to the audio file
            output_dir: Directory to save chunks (created if missing)
            chunk_duration_sec: Target chunk duration (defaults to max_duration_sec)

        Returns:
            List of AudioChunk objects; ``total_chunks`` on each equals the
            length of the returned list

        Raises:
            ValueError: If the file needs splitting and the chunk duration is
                not positive or the overlap is negative or not smaller than
                the chunk duration
        """
        if chunk_duration_sec is None:
            chunk_duration_sec = self.max_duration_sec

        audio = AudioSegment.from_file(file_path)
        total_duration_ms = len(audio)
        total_duration_sec = total_duration_ms / 1000.0

        os.makedirs(output_dir, exist_ok=True)
        source = Path(file_path)

        # If already short enough, return as single chunk
        if total_duration_sec <= self.max_duration_sec:
            return [AudioChunk(
                file_path=file_path,
                original_segment=source.name,
                chunk_index=0,
                total_chunks=1,
                start_ms=0,
                end_ms=total_duration_ms,
                duration_sec=total_duration_sec
            )]

        bounds = self._plan_chunk_bounds(
            total_duration_ms,
            int(chunk_duration_sec * 1000),
            int(self.overlap_sec * 1000),
            self.min_duration_sec * 1000
        )
        total_chunks = len(bounds)

        chunks = []
        for index, (start_ms, end_ms) in enumerate(bounds):
            chunk_audio = audio[start_ms:end_ms]

            chunk_filename = f"{source.stem}_chunk{index:03d}.{self.output_format}"
            chunk_path = os.path.join(output_dir, chunk_filename)
            chunk_audio.export(chunk_path, format=self.output_format)

            chunks.append(AudioChunk(
                file_path=chunk_path,
                original_segment=source.name,
                chunk_index=index,
                total_chunks=total_chunks,
                start_ms=start_ms,
                end_ms=end_ms,
                duration_sec=len(chunk_audio) / 1000.0
            ))

        return chunks

    def process_segments_directory(
        self,
        segments_dir: str,
        output_dir: Optional[str] = None,
        max_segments: Optional[int] = None,
        skip_short: bool = True
    ) -> List["AudioChunk"]:
        """
        Process all segments in a directory, splitting as needed.

        Only files directly inside ``segments_dir`` with a .flac, .wav, .mp3,
        .ogg or .m4a extension are considered, in filename order. Progress
        and a summary are printed to stdout.

        Args:
            segments_dir: Directory containing audio segments
            output_dir: Directory for split chunks (default: segments_dir/chunks)
            max_segments: Maximum number of segments to process (for testing);
                a falsy value (None or 0) means no limit
            skip_short: Skip segments shorter than min_duration_sec

        Returns:
            List of AudioChunk objects ready for transcription
        """
        if output_dir is None:
            output_dir = os.path.join(segments_dir, "chunks")

        # Find all audio files
        audio_extensions = {'.flac', '.wav', '.mp3', '.ogg', '.m4a'}
        segment_files = []

        for ext in audio_extensions:
            segment_files.extend(Path(segments_dir).glob(f"*{ext}"))

        # Sort by filename for consistent ordering
        segment_files = sorted(segment_files, key=lambda x: x.name)

        if max_segments:
            segment_files = segment_files[:max_segments]

        print(f"[Audio] Processing {len(segment_files)} segments...")

        all_chunks = []
        split_count = 0
        skipped_count = 0

        for segment_path in segment_files:
            duration = self.get_audio_duration(str(segment_path))

            # Skip very short segments
            if skip_short and duration < self.min_duration_sec:
                skipped_count += 1
                continue

            # Split if needed
            if duration > self.max_duration_sec:
                chunks = self.split_audio(str(segment_path), output_dir)
                split_count += 1
                print(f"[Audio] Split {segment_path.name}: {duration:.1f}s -> {len(chunks)} chunks")
            else:
                # Use original file
                chunks = [AudioChunk(
                    file_path=str(segment_path),
                    original_segment=segment_path.name,
                    chunk_index=0,
                    total_chunks=1,
                    start_ms=0,
                    # round, not int: seconds -> ms can land just below the
                    # integer (1.001 * 1000 == 1000.9999999999999).
                    end_ms=round(duration * 1000),
                    duration_sec=duration
                )]

            all_chunks.extend(chunks)

        print(f"[Audio] Processed {len(segment_files)} segments -> {len(all_chunks)} chunks")
        print(f"[Audio] Split: {split_count}, Skipped (too short): {skipped_count}")

        return all_chunks


def get_segment_stats(segments_dir: str) -> Dict:
    """
    Summarize the durations of the audio segments in a directory.

    Files directly inside ``segments_dir`` with a .flac, .wav, .mp3, .ogg or
    .m4a extension are measured. A file that cannot be read is reported on
    stdout and left out of the statistics.

    Args:
        segments_dir: Directory containing audio segments

    Returns:
        Dictionary with the file count, total/average/min/max duration in
        seconds and the number of files over 10s and under 1s, or
        ``{"error": "No audio files found"}`` when nothing could be measured
    """
    processor = AudioProcessor()
    root = Path(segments_dir)

    known_suffixes = {'.flac', '.wav', '.mp3', '.ogg', '.m4a'}
    measured = []

    for suffix in known_suffixes:
        pattern = f"*{suffix}"
        for f in root.glob(pattern):
            try:
                seconds = processor.get_audio_duration(str(f))
            except Exception as e:
                print(f"Error reading {f}: {e}")
            else:
                measured.append(seconds)

    file_count = len(measured)
    if file_count == 0:
        return {"error": "No audio files found"}

    total_seconds = sum(measured)
    too_long = [d for d in measured if d > 10]
    too_short = [d for d in measured if d < 1]

    return {
        "total_files": file_count,
        "total_duration_sec": total_seconds,
        "avg_duration_sec": total_seconds / file_count,
        "min_duration_sec": min(measured),
        "max_duration_sec": max(measured),
        "over_10s_count": len(too_long),
        "under_1s_count": len(too_short)
    }


if __name__ == "__main__":
    # Manual smoke test: print duration statistics for the directory given
    # as the first command-line argument; do nothing when none is given.
    import sys

    cli_args = sys.argv[1:]
    if cli_args:
        report = get_segment_stats(cli_args[0])
        print("\nSegment Statistics:")
        for label in report:
            print(f"  {label}: {report[label]}")
