#!/usr/bin/env python3
"""
Shared Audio Buffer - Load Once, Use Everywhere

=== OPTIMIZATION (v6.2) ===
Previously: Each stage (VAD, diarization, embeddings, quality) called torchaudio.load()
Now: Load audio ONCE, pass numpy array/slices to all stages

Benefits:
- Eliminates 4-5 redundant file reads
- Reduces disk I/O by ~80%
- Faster processing (audio is decoded once, not once per stage)
- Memory-efficient (single copy)
"""

import logging
import numpy as np
import torch
import torchaudio
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import atexit
import shutil

# Module logger. It uses an explicit dotted name rather than __name__, presumably
# so this module's output nests under the pipeline's "FastPipelineV6" logger
# hierarchy — confirm against the pipeline's logging configuration.
logger = logging.getLogger("FastPipelineV6.AudioBuffer")


@dataclass
class AudioBuffer:
    """
    Holds a decoded mono waveform in memory and provides slicing utilities.

    The file is decoded once (see ``from_file``). The slicing accessors return
    views that share memory with ``waveform_np`` rather than copies. Callers
    that mutate the returned array/tensor in place must ``.copy()``/``.clone()``
    it first, or they will corrupt the shared buffer for every other stage.

    Usage:
        buffer = AudioBuffer.from_file("audio.wav")
        vad_result = run_vad(buffer)
        embeddings = extract_embeddings(buffer, segments)
    """
    waveform_np: np.ndarray  # Shape: (samples,) - mono
    sample_rate: int  # Samples per second
    audio_path: str  # Original path for reference
    duration: float  # Length in seconds (samples / sample_rate when built by from_file)

    @staticmethod
    def _to_mono_numpy(waveform: torch.Tensor) -> np.ndarray:
        """Convert a (channels, samples) tensor to a 1-D numpy array.

        ``squeeze(0)`` alone only removes the channel axis when there is exactly
        one channel. A stereo file would stay 2-D, and ``len()`` would then
        report the channel count instead of the sample count. Multi-channel
        input is therefore down-mixed by averaging channels. Mono input keeps
        the original zero-copy view.
        """
        if waveform.dim() == 1:
            return waveform.numpy()
        if waveform.shape[0] == 1:
            return waveform[0].numpy()
        return waveform.mean(dim=0).numpy()

    @classmethod
    def from_file(cls, audio_path: str) -> 'AudioBuffer':
        """Decode ``audio_path`` once and wrap it in a buffer.

        Multi-channel audio is down-mixed to mono by averaging the channels.

        Args:
            audio_path: Path of the audio file, passed to ``torchaudio.load``.

        Returns:
            A buffer holding the mono waveform, its sample rate, the original
            path, and the duration in seconds.
        """
        logger.info(f"📥 Loading audio into shared buffer: {Path(audio_path).name}")
        waveform, sr = torchaudio.load(audio_path)
        channels = waveform.shape[0] if waveform.dim() > 1 else 1
        waveform_np = cls._to_mono_numpy(waveform)
        if channels > 1:
            logger.info(f"   Down-mixed {channels} channels to mono")
        duration = len(waveform_np) / sr
        logger.info(f"   Loaded: {duration:.1f}s @ {sr}Hz ({len(waveform_np):,} samples)")
        return cls(
            waveform_np=waveform_np,
            sample_rate=sr,
            audio_path=audio_path,
            duration=duration
        )

    def _sec_to_sample(self, seconds: float) -> int:
        """Map a time in seconds to a sample index (truncated, clamped at 0).

        Clamping prevents a negative time from becoming a negative index,
        which numpy would interpret as "count from the end".
        """
        return max(0, int(seconds * self.sample_rate))

    def get_slice(self, start_sec: float, end_sec: float) -> np.ndarray:
        """Return the samples in ``[start_sec, end_sec)`` as a view of the buffer.

        Times before 0 are clamped to 0. Times past the end are clipped by
        numpy slicing.
        """
        return self.waveform_np[self._sec_to_sample(start_sec):self._sec_to_sample(end_sec)]

    def get_slice_samples(self, start_sample: int, end_sample: int) -> np.ndarray:
        """Return ``waveform_np[start_sample:end_sample]`` (plain numpy slicing semantics)."""
        return self.waveform_np[start_sample:end_sample]

    def get_torch_tensor(self, start_sec: float = 0, end_sec: Optional[float] = None) -> torch.Tensor:
        """Return a time range as a float32 torch tensor for GPU processing.

        Args:
            start_sec: Start time in seconds (clamped at 0).
            end_sec: End time in seconds. ``None`` means "to the end of the buffer".

        Returns:
            A 1-D float tensor. If the buffer is already float32, it shares
            memory with ``waveform_np``.
        """
        if end_sec is None:
            # Slice to the true end. Recomputing the index as
            # int(duration * sample_rate) can round down and drop the last
            # sample (e.g. 1/49 * 49 == 0.999...).
            chunk = self.waveform_np[self._sec_to_sample(start_sec):]
        else:
            chunk = self.get_slice(start_sec, end_sec)
        return torch.from_numpy(chunk).float()

    def get_full_waveform_2d(self) -> Tuple[torch.Tensor, int]:
        """Return ``(waveform, sample_rate)`` with waveform shaped (1, samples).

        The layout matches ``torchaudio.load``. The tensor shares memory with
        ``waveform_np``.
        """
        return torch.from_numpy(self.waveform_np).unsqueeze(0), self.sample_rate


class TempDirManager:
    """
    Process-wide (singleton) registry of temporary directories to delete at exit.

    === OPTIMIZATION (v6.2) ===
    Registers an ``atexit`` handler so registered temp dirs are removed on
    normal interpreter shutdown, including shutdown after an unhandled
    exception. Prevents disk quota exhaustion from orphaned chunk files.

    Note: ``atexit`` handlers do not run if the process is killed by a signal
    Python does not handle (e.g. SIGKILL) or exits via ``os._exit()``.
    """
    _instance = None
    _temp_dirs: list = []  # Shadowed by a per-instance list created in __new__

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._temp_dirs = []
            # Register cleanup on interpreter exit (normal or unhandled exception)
            atexit.register(cls._instance.cleanup_all)
        return cls._instance

    def register(self, path: Path):
        """Register a temp directory for cleanup on exit.

        ``str`` and ``Path`` are both accepted and stored as ``Path``. This
        lets ``cleanup_all`` call ``.exists()`` and lets ``unregister`` match
        regardless of which form the caller used. Re-registering the same
        path is a no-op.
        """
        path = Path(path)
        if path not in self._temp_dirs:
            self._temp_dirs.append(path)

    def unregister(self, path: Path):
        """Unregister a temp directory (e.g. after it was removed successfully)."""
        path = Path(path)
        if path in self._temp_dirs:
            self._temp_dirs.remove(path)

    def cleanup_all(self):
        """Delete every registered temp directory that still exists.

        Directories that could not be fully removed stay registered, so a later
        call (such as the atexit one) can retry. Errors are logged, never raised.
        """
        for path in list(self._temp_dirs):  # Copy: entries are removed while looping
            try:
                if path.exists():
                    # ignore_errors=True keeps deleting past individual failures;
                    # check afterwards whether anything was left behind.
                    shutil.rmtree(path, ignore_errors=True)
                    if path.exists():
                        logger.warning(f"   Failed to cleanup {path}: directory still present")
                        continue
                    logger.debug(f"   Cleaned up temp dir: {path}")
                self._temp_dirs.remove(path)
            except Exception as e:  # Best-effort cleanup: never abort the exit path
                logger.warning(f"   Failed to cleanup {path}: {e}")

    def cleanup_for_audio(self, audio_path: str):
        """Delete the ``chunks`` directory next to ``audio_path`` and unregister it."""
        chunk_dir = Path(audio_path).parent / "chunks"
        if not chunk_dir.exists():
            return
        try:
            shutil.rmtree(chunk_dir)
        except Exception as e:
            logger.warning(f"   Failed to cleanup chunks: {e}")
        else:
            self.unregister(chunk_dir)
            logger.debug(f"   Cleaned up chunks: {chunk_dir}")


# Global instance. TempDirManager is a singleton (see its __new__), so this first
# construction happens at import time and also registers the atexit cleanup handler.
TEMP_MANAGER = TempDirManager()


def get_vram_info() -> Tuple[float, float, float]:
    """
    Get GPU memory info for memory-aware batching.

    === OPTIMIZATION (v6.2) ===
    Use torch.cuda.mem_get_info() to query actual VRAM availability.
    Scale batch sizes dynamically based on available memory.

    Both the primary query and the fallback refer to the *current* CUDA device.

    Returns:
        (free_gb, total_gb, utilization_pct); all zeros when CUDA is unavailable.
        "GB" here means GiB (1024**3 bytes).
    """
    if not torch.cuda.is_available():
        return (0.0, 0.0, 0.0)

    gib = 1024 ** 3
    try:
        # Driver-level query for the current device. It accounts for memory
        # used by other processes too.
        free, total = torch.cuda.mem_get_info()
    except (AttributeError, RuntimeError):
        # Fallback for older PyTorch (no mem_get_info) or a failed driver query.
        # This only sees this process's caching allocator, so memory held by
        # other processes is not subtracted.
        device = torch.cuda.current_device()
        total = torch.cuda.get_device_properties(device).total_memory
        free = total - torch.cuda.memory_reserved(device)
    return (free / gib, total / gib, (1 - free / total) * 100)


def compute_optimal_batch_size(
    num_items: int,
    avg_item_samples: int,
    base_batch: int = 8,
    min_batch: int = 2,
    max_batch: int = 64,
    vram_headroom_gb: float = 2.0
) -> int:
    """
    Compute optimal batch size based on available VRAM.

    === OPTIMIZATION (v6.2) ===
    Instead of fixed batch sizes, dynamically scale based on:
    - Available VRAM (torch.cuda.mem_get_info)
    - Average item size (samples)
    - Safety headroom for GPU operations

    Args:
        num_items: Total items to process
        avg_item_samples: Average samples per item
        base_batch: Batch size used when no GPU is available
        min_batch: Minimum batch size on GPU (still capped by num_items)
        max_batch: Maximum batch size
        vram_headroom_gb: Reserved VRAM headroom (GiB)

    Returns:
        Optimal batch size. It is always >= 1, so it is safe as a ``range()``
        step, and never larger than ``num_items`` when ``num_items`` > 0.
    """
    # Upper bound from the workload: never batch more items than exist, and
    # never return 0 (a zero step makes range(0, n, batch) raise ValueError).
    item_cap = max(1, num_items)

    free_gb, total_gb, _ = get_vram_info()

    if total_gb == 0:  # CPU-only: get_vram_info reports all zeros without CUDA
        return max(1, min(base_batch, item_cap))

    # Available VRAM for batching (minus headroom), with a 0.5 GiB floor
    available_gb = max(0.5, free_gb - vram_headroom_gb)

    # Estimate memory per batch item: 4 bytes per float32 sample for the audio
    # itself, plus ~2x that for model activations -> conservative 3x multiplier
    bytes_per_item = avg_item_samples * 4 * 3
    gb_per_item = bytes_per_item / (1024**3)

    # Calculate max batch that fits in available VRAM
    if gb_per_item > 0:
        vram_limited_batch = int(available_gb / gb_per_item)
    else:
        vram_limited_batch = max_batch

    # Clamp to [min_batch, max_batch], then cap by the number of items so that
    # min_batch cannot push the batch above the available work.
    optimal = min(max(min_batch, min(vram_limited_batch, max_batch)), item_cap)

    logger.debug(f"   VRAM-aware batch: {optimal} (free: {free_gb:.1f}GB, "
                 f"per-item: {bytes_per_item / (1024**2):.1f}MB)")

    return optimal
