#!/usr/bin/env python3
"""Utility functions for the pipeline."""

import logging
import shutil
import time
from pathlib import Path
from typing import List, Dict

import torchaudio

# Module logger. Named explicitly (not via __name__) as a child of
# "FastPipelineV5", so records propagate to that parent logger (presumably
# configured elsewhere in the pipeline — confirm).
logger = logging.getLogger("FastPipelineV5.Utils")


def generate_samples(
    audio_path: str,
    segments: List[Dict],
    config,
    output_dir: Path
) -> Dict[str, List[str]]:
    """
    Generate sample clips for each speaker for validation.

    Strategy: pick up to 3 segments per speaker (longest, median, shortest),
    further capped by ``config.clips_per_speaker``, and write each one as a
    WAV clip of at most 5 seconds taken from the start of the segment.

    Args:
        audio_path: Path of the source audio, loaded with ``torchaudio.load``.
        segments: Segment dicts. Only entries whose ``'status'`` is
            ``'usable'`` and whose ``'speaker'`` is not ``'OVERLAP'`` or
            ``'NON_SPEECH'`` are used; those must carry ``'speaker'``,
            ``'start'``, ``'end'`` and ``'duration'`` (presumably seconds,
            since they are multiplied by the sample rate — confirm).
        config: Object exposing an integer ``clips_per_speaker``.
        output_dir: Directory under which ``speaker_samples/`` is created.
            Missing parent directories are created as needed.

    Returns:
        Mapping of speaker label to the list of written clip paths (as str).
        Clips that would be empty (zero-length or out-of-range segment) are
        skipped, so a speaker's list may be shorter than requested.
    """
    logger.info("🎧 Generating samples...")
    t0 = time.time()

    samples_dir = output_dir / "speaker_samples"
    # parents=True: do not fail when output_dir itself has not been created yet.
    samples_dir.mkdir(parents=True, exist_ok=True)

    waveform, sr = torchaudio.load(audio_path)
    total_frames = waveform.shape[-1]

    # Group usable, single-speaker segments by speaker label.
    excluded = {'OVERLAP', 'NON_SPEECH'}
    speaker_segments: Dict[str, List[Dict]] = {}
    for seg in segments:
        if seg.get('status') == 'usable' and seg['speaker'] not in excluded:
            speaker_segments.setdefault(seg['speaker'], []).append(seg)

    sample_paths: Dict[str, List[str]] = {}

    for speaker, segs in speaker_segments.items():
        # Longest first: index 0 is the longest, n - 1 the shortest.
        sorted_segs = sorted(segs, key=lambda s: s['duration'], reverse=True)
        n = len(sorted_segs)
        indices = [0, n // 2, n - 1] if n >= 3 else list(range(n))

        sample_paths[speaker] = []

        for idx, seg_idx in enumerate(indices[:config.clips_per_speaker]):
            seg = sorted_segs[seg_idx]

            # Clip spans min(5s, max(2s, duration)) from the segment start,
            # but never past the segment end.
            clip_dur = min(5.0, max(2.0, seg['duration']))
            # Clamp to the waveform: a negative index would slice from the end.
            start_sample = max(0, int(seg['start'] * sr))
            end_sample = min(
                total_frames,
                int(min(seg['start'] + clip_dur, seg['end']) * sr),
            )

            if end_sample <= start_sample:
                # Zero-length or out-of-range segment: nothing worth saving.
                logger.warning(
                    "   Skipping empty sample clip for %s (segment %.2f-%.2fs)",
                    speaker, seg['start'], seg['end'],
                )
                continue

            clip = waveform[:, start_sample:end_sample]
            clip_path = samples_dir / f"{speaker}_sample_{idx + 1}.wav"
            torchaudio.save(str(clip_path), clip, sr)
            sample_paths[speaker].append(str(clip_path))

    elapsed = time.time() - t0
    total = sum(len(v) for v in sample_paths.values())
    logger.info(f"✅ Samples: {elapsed:.1f}s | {total} clips")

    return sample_paths


def cleanup_chunks(audio_path: str):
    """Remove the temporary ``chunks`` directory next to *audio_path*.

    Deletes ``<parent of audio_path>/chunks`` together with everything inside
    it, including nested subdirectories. Does nothing if that path does not
    exist or is not a directory.

    Args:
        audio_path: Path of the audio file whose sibling ``chunks`` directory
            should be removed.

    Raises:
        OSError: If the directory exists but cannot be removed.
    """
    chunk_dir = Path(audio_path).parent / "chunks"
    # is_dir() rather than exists(): a plain file named "chunks" is not ours
    # to delete, and iterating it as a directory would raise.
    if not chunk_dir.is_dir():
        return
    # rmtree also handles nested subdirectories, which per-entry unlink()
    # followed by rmdir() could not (unlink fails on directories).
    shutil.rmtree(chunk_dir)
    logger.info("   Cleaned up chunks directory")

