"""
Audio processing module for handling segment duration and chunking.

Splitting strategy: VAD-aware energy-based cutting.
When a segment exceeds max_duration_sec, the splitter finds the lowest-energy
point in a window (8-13s) around the target split time to avoid cutting mid-speech.
Sequential chunks track their original segment + cut points for later merging.
"""
import os
import math
import numpy as np
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass
from pydub import AudioSegment


@dataclass
class AudioChunk:
    """A single transcription-ready piece of audio.

    A chunk either wraps an entire original segment (total_chunks == 1)
    or is one slice of a longer segment that was split; start_ms/end_ms
    locate the slice within the original file.
    """
    file_path: str         # path to this chunk's audio file on disk
    original_segment: str  # filename of the segment this chunk came from
    chunk_index: int       # position within the segment; 0 when unsplit
    total_chunks: int      # how many chunks the segment produced
    start_ms: int          # slice start inside the original, in milliseconds
    end_ms: int            # slice end inside the original, in milliseconds
    duration_sec: float    # chunk length in seconds

    @property
    def is_split(self) -> bool:
        """Whether this chunk is one piece of a split segment."""
        return self.total_chunks > 1


class AudioProcessor:
    """Handles audio segment processing with duration controls.

    Segments longer than ``max_duration_sec`` are split using VAD-aware,
    energy-based cutting: each cut is placed at the quietest point near the
    target split time so speech is not chopped mid-word. Chunk metadata
    (original segment name, cut offsets) is kept so transcripts can be
    merged back later.
    """

    # Hard clamps applied to every VAD-aware cut: no chunk is made shorter
    # than _MIN_CUT_MS or longer than _MAX_CUT_MS milliseconds.
    _MIN_CUT_MS = 5000
    _MAX_CUT_MS = 15000

    def __init__(
        self,
        max_duration_sec: float = 10.0,
        min_duration_sec: float = 2.0,
        overlap_sec: float = 0.0,
        output_format: str = "flac"
    ):
        """
        Initialize the audio processor.

        Args:
            max_duration_sec: Maximum allowed segment duration (hard limit)
            min_duration_sec: Minimum segment duration to process
            overlap_sec: Overlap between chunks when splitting (helps with
                word boundaries); each non-first chunk is extended backwards
                by this amount
            output_format: Output format for split chunks (flac, wav, mp3)
        """
        self.max_duration_sec = max_duration_sec
        self.min_duration_sec = min_duration_sec
        self.overlap_sec = overlap_sec
        self.output_format = output_format

    def get_audio_duration(self, file_path: str) -> float:
        """Get duration of an audio file in seconds."""
        audio = AudioSegment.from_file(file_path)
        # pydub reports length in milliseconds via len().
        return len(audio) / 1000.0

    def needs_splitting(self, file_path: str) -> bool:
        """Check if a file needs to be split based on duration."""
        return self.get_audio_duration(file_path) > self.max_duration_sec

    def _to_mono_float(self, audio: "AudioSegment") -> np.ndarray:
        """Convert a pydub AudioSegment to a mono float array in [-1, 1].

        Fix: the original normalized by 2**15 unconditionally, which is only
        correct for 16-bit audio; this scales by the actual sample width.
        (For valley *search* only relative energy matters, but correct
        scaling keeps the values meaningful if reused elsewhere.)
        """
        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
        if audio.channels > 1:
            # Average interleaved channels down to mono for energy analysis.
            samples = samples.reshape(-1, audio.channels).mean(axis=1)
        full_scale = float(2 ** (8 * audio.sample_width - 1))
        return samples / full_scale

    def _find_energy_valley(self, samples: np.ndarray, sr: int,
                            target_ms: int, search_window_ms: int = 3000
                            ) -> int:
        """
        Find the lowest-energy point near target_ms for VAD-aware splitting.

        Scans +/- search_window_ms around target_ms in 20 ms frames and
        returns the start (in ms) of the quietest frame.

        Args:
            samples: Mono float samples (only relative energy matters).
            sr: Sample rate in Hz.
            target_ms: Preferred cut position in milliseconds.
            search_window_ms: Half-width of the search window in ms.

        Returns:
            Best cut point in milliseconds (target_ms if no frame could be
            evaluated).
        """
        # RMS energy in 20ms windows
        frame_ms = 20
        frame_samples = int(sr * frame_ms / 1000)
        total_ms = int(len(samples) / sr * 1000)

        search_start = max(0, target_ms - search_window_ms)
        search_end = min(total_ms, target_ms + search_window_ms)

        best_energy = float('inf')
        best_ms = target_ms  # fallback to target if no valley found

        for t_ms in range(search_start, search_end, frame_ms):
            idx = int(t_ms * sr / 1000)
            frame = samples[idx:idx + frame_samples]
            # Ignore badly truncated frames at the very end of the audio.
            if len(frame) < frame_samples // 2:
                continue
            rms = np.sqrt(np.mean(frame ** 2)) + 1e-10  # epsilon avoids log(0)
            energy_db = 20 * np.log10(rms)
            if energy_db < best_energy:
                best_energy = energy_db
                best_ms = t_ms

        return best_ms

    def _compute_split_points(self, samples: np.ndarray, sr: int,
                              total_duration_ms: int,
                              chunk_duration_sec: float) -> List[int]:
        """Walk the audio accumulating energy-valley cut points.

        Returns a strictly increasing list of millisecond offsets starting
        at 0 and ending at total_duration_ms.
        """
        target_ms = int(chunk_duration_sec * 1000)
        split_points = [0]  # start of first chunk
        cursor = 0

        while cursor + target_ms < total_duration_ms:
            # Look for lowest energy near the target split point.
            candidate = cursor + target_ms
            cut_ms = self._find_energy_valley(samples, sr, candidate)
            chunk_len = cut_ms - cursor
            # Clamp cuts so no chunk is absurdly short or long.
            if chunk_len < self._MIN_CUT_MS:
                cut_ms = cursor + self._MIN_CUT_MS
            elif chunk_len > self._MAX_CUT_MS:
                cut_ms = cursor + target_ms
            # Fix: never step past the end (possible when a small
            # chunk_duration_sec forces the _MIN_CUT_MS clamp).
            cut_ms = min(cut_ms, total_duration_ms)
            if cut_ms <= cursor:
                break  # safety: guarantee forward progress
            split_points.append(cut_ms)
            cursor = cut_ms

        if split_points[-1] < total_duration_ms:
            split_points.append(total_duration_ms)  # end of last chunk
        return split_points

    def split_audio(
        self,
        file_path: str,
        output_dir: str,
        chunk_duration_sec: Optional[float] = None
    ) -> "List[AudioChunk]":
        """
        Split audio using VAD-aware energy-based cutting.

        Instead of chopping at fixed intervals, finds the lowest-energy point
        in a window around the target split time to avoid cutting mid-speech.
        Sequential chunks track their original segment + cut points so they
        can be merged back later (transcripts appended, original audio
        retained). If ``overlap_sec`` was configured, each non-first chunk is
        extended backwards by that amount (the original implementation
        accepted but silently ignored the setting).

        Args:
            file_path: Path to the audio file
            output_dir: Directory to save chunks
            chunk_duration_sec: Target chunk duration (defaults to
                max_duration_sec)

        Returns:
            List of AudioChunk objects with sequential chunk indices;
            ``total_chunks`` counts the chunks actually written (tiny
            trailing pieces below min_duration_sec are dropped).
        """
        if chunk_duration_sec is None:
            chunk_duration_sec = self.max_duration_sec

        audio = AudioSegment.from_file(file_path)
        total_duration_ms = len(audio)
        total_duration_sec = total_duration_ms / 1000.0

        os.makedirs(output_dir, exist_ok=True)
        original_name = Path(file_path).stem

        # Already short enough: hand back the original file as one chunk.
        if total_duration_sec <= self.max_duration_sec:
            return [AudioChunk(
                file_path=file_path,
                original_segment=Path(file_path).name,
                chunk_index=0,
                total_chunks=1,
                start_ms=0,
                end_ms=total_duration_ms,
                duration_sec=total_duration_sec
            )]

        # Convert to numpy for energy analysis.
        samples = self._to_mono_float(audio)
        sr = audio.frame_rate

        split_points = self._compute_split_points(
            samples, sr, total_duration_ms, chunk_duration_sec)

        # Decide which spans to keep first, dropping tiny non-initial
        # pieces, so total_chunks below reflects the chunks actually written
        # (the original reported the pre-skip count) and filenames have no
        # index gaps.
        spans = []
        for i in range(len(split_points) - 1):
            start_ms, end_ms = split_points[i], split_points[i + 1]
            if (end_ms - start_ms) / 1000.0 < self.min_duration_sec and spans:
                continue  # skip tiny trailing chunk
            spans.append((start_ms, end_ms))

        # Export chunks, applying the configured overlap (default 0.0 keeps
        # the original cut boundaries exactly).
        overlap_ms = int(self.overlap_sec * 1000)
        chunks: List[AudioChunk] = []
        for idx, (start_ms, end_ms) in enumerate(spans):
            if idx > 0 and overlap_ms > 0:
                start_ms = max(0, start_ms - overlap_ms)
            chunk_audio = audio[start_ms:end_ms]
            chunk_filename = f"{original_name}_chunk{idx:03d}.{self.output_format}"
            chunk_path = os.path.join(output_dir, chunk_filename)
            chunk_audio.export(chunk_path, format=self.output_format)

            chunks.append(AudioChunk(
                file_path=chunk_path,
                original_segment=Path(file_path).name,
                chunk_index=idx,
                total_chunks=len(spans),
                start_ms=start_ms,
                end_ms=end_ms,
                duration_sec=len(chunk_audio) / 1000.0
            ))

        return chunks

    def process_segments_directory(
        self,
        segments_dir: str,
        output_dir: Optional[str] = None,
        max_segments: Optional[int] = None,
        skip_short: bool = True
    ) -> "List[AudioChunk]":
        """
        Process all segments in a directory, splitting as needed.

        Args:
            segments_dir: Directory containing audio segments
            output_dir: Directory for split chunks (default: segments_dir/chunks)
            max_segments: Maximum number of segments to process (for testing)
            skip_short: Skip segments shorter than min_duration_sec

        Returns:
            List of AudioChunk objects ready for transcription
        """
        if output_dir is None:
            output_dir = os.path.join(segments_dir, "chunks")

        # Find all audio files
        audio_extensions = {'.flac', '.wav', '.mp3', '.ogg', '.m4a'}
        segment_files = []

        for ext in audio_extensions:
            segment_files.extend(Path(segments_dir).glob(f"*{ext}"))

        # Sort by filename for consistent ordering
        segment_files = sorted(segment_files, key=lambda x: x.name)

        # Fix: explicit None check so max_segments=0 means "no segments"
        # rather than silently processing everything.
        if max_segments is not None:
            segment_files = segment_files[:max_segments]

        print(f"[Audio] Processing {len(segment_files)} segments...")

        all_chunks = []
        split_count = 0
        skipped_count = 0

        for segment_path in segment_files:
            duration = self.get_audio_duration(str(segment_path))

            # Skip very short segments
            if skip_short and duration < self.min_duration_sec:
                skipped_count += 1
                continue

            # Split if needed
            if duration > self.max_duration_sec:
                chunks = self.split_audio(str(segment_path), output_dir)
                split_count += 1
                print(f"[Audio] Split {segment_path.name}: {duration:.1f}s -> {len(chunks)} chunks")
            else:
                # Use original file
                chunks = [AudioChunk(
                    file_path=str(segment_path),
                    original_segment=segment_path.name,
                    chunk_index=0,
                    total_chunks=1,
                    start_ms=0,
                    end_ms=int(duration * 1000),
                    duration_sec=duration
                )]

            all_chunks.extend(chunks)

        print(f"[Audio] Processed {len(segment_files)} segments -> {len(all_chunks)} chunks")
        print(f"[Audio] Split: {split_count}, Skipped (too short): {skipped_count}")

        return all_chunks


def get_segment_stats(segments_dir: str) -> Dict:
    """
    Get statistics about segments in a directory.

    Args:
        segments_dir: Directory containing audio segments

    Returns:
        Dictionary with duration statistics, or {"error": ...} when the
        directory contains no readable audio files
    """
    processor = AudioProcessor()

    # Probe every recognized audio file; unreadable files are reported
    # but do not abort the scan.
    durations = []
    for ext in ('.flac', '.wav', '.mp3', '.ogg', '.m4a'):
        for f in Path(segments_dir).glob(f"*{ext}"):
            try:
                durations.append(processor.get_audio_duration(str(f)))
            except Exception as e:
                print(f"Error reading {f}: {e}")

    if not durations:
        return {"error": "No audio files found"}

    total = sum(durations)
    count = len(durations)
    return {
        "total_files": count,
        "total_duration_sec": total,
        "avg_duration_sec": total / count,
        "min_duration_sec": min(durations),
        "max_duration_sec": max(durations),
        "over_10s_count": sum(1 for d in durations if d > 10),
        "under_1s_count": sum(1 for d in durations if d < 1)
    }


if __name__ == "__main__":
    # CLI smoke test: pass a segments directory and print its stats.
    import sys

    if len(sys.argv) > 1:
        stats = get_segment_stats(sys.argv[1])
        print("\nSegment Statistics:")
        for stat_name, stat_value in stats.items():
            print(f"  {stat_name}: {stat_value}")
