"""
Super-resolution service for upsampling audio from 16kHz to 48kHz using AP-BWE.

Features:
- Single model instance per worker (loaded once, kept hot)
- Streaming-compatible chunk processing
- GPU acceleration with FP16 support
- Minimal latency (<10ms per chunk)
"""

import os
import sys
import json
import time
import logging
from typing import Optional, Tuple
import numpy as np
import torch
import torchaudio
import torchaudio.functional as aF
from pathlib import Path

# Ensure AP-BWE (vendored under external/) is importable in local/dev runs.
# In Modal, PYTHONPATH already includes `/root/external/AP-BWE`.
_THIS_FILE = Path(__file__).resolve()
# Walk up the parent directories and take the first one that contains a
# vendored external/AP-BWE checkout; None when no such directory exists.
_AP_BWE_PATH = next(
    (
        parent / "external" / "AP-BWE"
        for parent in _THIS_FILE.parents
        if (parent / "external" / "AP-BWE").is_dir()
    ),
    None,
)
if _AP_BWE_PATH is not None and str(_AP_BWE_PATH) not in sys.path:
    sys.path.insert(0, str(_AP_BWE_PATH))

from env import AttrDict
from models.model import APNet_BWE_Model
from datasets.dataset import amp_pha_stft, amp_pha_istft

# Module-level logger named after this module so log output can be filtered per file.
logger = logging.getLogger(__name__)


class SuperResolutionService:
    """
    Service for audio super-resolution using AP-BWE model.

    This service manages:
    - Model loading and initialization
    - Device placement (GPU/CPU)
    - Chunk-based processing for streaming
    - Warmup and optimization
    """

    _instance = None        # Singleton instance (one SR model per worker process)
    _initialized = False    # Guards __init__ so its body runs only once
    _checkpoint_dir = None  # Store checkpoint_dir at class level for singleton

    def __new__(cls, checkpoint_dir: Optional[str] = None):
        """Singleton pattern - one SR model per worker process."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        # Store checkpoint_dir for __init__ to use.
        # NOTE: only the first initialization consumes this value; passing a
        # different dir after the singleton was built has no effect.
        if checkpoint_dir:
            cls._checkpoint_dir = checkpoint_dir
        return cls._instance

    def __init__(self, checkpoint_dir: Optional[str] = None):
        """Initialize the service (only runs once due to singleton)."""
        if not self._initialized:
            self.model = None       # APNet_BWE_Model once load_model() succeeds
            self.config = None      # AttrDict parsed from config.json
            self.device = None      # torch.device chosen in load_model()
            self.is_loaded = False
            self.use_fp16 = False
            # Use provided checkpoint_dir, class-level stored dir, env var, or default path
            self.model_path = (
                checkpoint_dir or
                self._checkpoint_dir or
                os.environ.get('AP_BWE_CHECKPOINT_DIR') or
                "/models/ap_bwe/16kto48k"  # Modal volume path
            )
            self.chunk_size_ms = 40  # Default chunk size in milliseconds
            self._initialized = True

            # Performance tracking
            self.total_chunks_processed = 0
            self.total_processing_time = 0.0

            logger.info(f"SuperResolutionService initialized (checkpoint: {self.model_path})")

    def load_model(self, device: Optional[str] = None, use_fp16: bool = False) -> bool:
        """
        Load the AP-BWE model for 16kHz -> 48kHz upsampling.

        Args:
            device: Device to use ('cuda', 'cpu', or None for auto-detect)
            use_fp16: Use FP16 for faster inference on GPU

        Returns:
            True if successfully loaded, False otherwise
        """
        if self.is_loaded:
            logger.info("SR model already loaded")
            return True

        try:
            # Load configuration
            config_path = os.path.join(self.model_path, "config.json")
            with open(config_path, 'r') as f:
                json_config = json.load(f)
            self.config = AttrDict(json_config)

            # Verify this is the 16k -> 48k model. Explicit raises rather than
            # `assert` so the validation survives `python -O` (asserts are
            # stripped under optimization); the surrounding except still
            # converts these into a False return.
            if self.config.lr_sampling_rate != 16000:
                raise ValueError(f"Expected 16kHz input, got {self.config.lr_sampling_rate}")
            if self.config.hr_sampling_rate != 48000:
                raise ValueError(f"Expected 48kHz output, got {self.config.hr_sampling_rate}")

            # Set device (auto-detect GPU when not specified)
            if device is None:
                self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            else:
                self.device = torch.device(device)

            # FP16 only makes sense on CUDA
            self.use_fp16 = use_fp16 and self.device.type == 'cuda'

            logger.info(f"Loading AP-BWE model on {self.device} (FP16: {self.use_fp16})")

            # Initialize model
            self.model = APNet_BWE_Model(self.config).to(self.device)

            # Load checkpoint.
            # SECURITY NOTE: torch.load unpickles arbitrary objects — the
            # checkpoint must come from a trusted volume/path only.
            checkpoint_path = os.path.join(self.model_path, "g_16kto48k")
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['generator'])

            # Set to eval mode and disable gradients (inference only)
            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False

            # Convert to FP16 if requested
            if self.use_fp16:
                self.model = self.model.half()
                logger.info("Model converted to FP16")

            # torch.compile deliberately NOT used: the first-chunk compile
            # cost (~5s) kills TTFB for the streaming use case.

            self.is_loaded = True
            logger.info(f"AP-BWE model loaded successfully from {self.model_path}")

            # Warmup the model so the first real chunk is fast
            self._warmup()

            return True

        except Exception as e:
            # Broad catch is intentional at this boundary: a load failure must
            # not crash the worker; callers check the boolean result.
            # logger.exception records the full traceback for diagnosis.
            logger.exception("Failed to load SR model: %s", e)
            self.is_loaded = False
            return False

    def _warmup(self, num_warmup: int = 3):
        """
        Warmup the model with dummy inputs to trigger JIT compilation.

        Args:
            num_warmup: Number of warmup passes
        """
        if not self.is_loaded:
            return

        logger.info(f"Warming up SR model with {num_warmup} passes...")

        # Create dummy input (40ms at 16kHz = 640 samples)
        chunk_samples = int(self.chunk_size_ms * 16)  # 16 samples per ms at 16kHz
        dummy_audio = torch.randn(1, chunk_samples, device=self.device)

        # torch.randn already yields float32; only a half() cast is needed
        # to match the model's dtype when running FP16.
        if self.use_fp16:
            dummy_audio = dummy_audio.half()

        with torch.no_grad():
            for i in range(num_warmup):
                start_time = time.time()
                _ = self.process_chunk(dummy_audio)
                elapsed = (time.time() - start_time) * 1000
                logger.info(f"Warmup pass {i+1}: {elapsed:.2f}ms")

        logger.info("SR model warmup complete")

    def process_chunk(self, audio_chunk: torch.Tensor) -> torch.Tensor:
        """
        Process a single audio chunk from 16kHz to 48kHz.

        Args:
            audio_chunk: Input audio tensor at 16kHz [batch, samples] or [samples]

        Returns:
            Upsampled audio tensor at 48kHz (exactly 3x the input length)

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self.is_loaded:
            raise RuntimeError("SR model not loaded. Call load_model() first.")

        with torch.no_grad():
            # Ensure correct shape [batch, samples]
            if audio_chunk.dim() == 1:
                audio_chunk = audio_chunk.unsqueeze(0)

            # Move to device and convert dtype if needed
            if audio_chunk.device != self.device:
                audio_chunk = audio_chunk.to(self.device)

            if self.use_fp16 and audio_chunk.dtype != torch.float16:
                audio_chunk = audio_chunk.half()
            elif not self.use_fp16 and audio_chunk.dtype == torch.float16:
                audio_chunk = audio_chunk.float()

            # First upsample to 48kHz (simple interpolation as input to model)
            audio_lr_48k = aF.resample(audio_chunk, orig_freq=16000, new_freq=48000)

            # Apply STFT
            amp_lr, pha_lr, _ = amp_pha_stft(
                audio_lr_48k,
                self.config.n_fft,
                self.config.hop_size,
                self.config.win_size
            )

            # Run through model
            amp_hr, pha_hr, _ = self.model(amp_lr, pha_lr)

            # Apply inverse STFT
            audio_hr = amp_pha_istft(
                amp_hr,
                pha_hr,
                self.config.n_fft,
                self.config.hop_size,
                self.config.win_size
            )

            # Ensure output is exactly the expected length (3x input for
            # 16k->48k): trim any STFT overlap excess, and zero-pad when the
            # ISTFT round-trip comes up short so downstream streaming buffers
            # always receive consistently sized chunks.
            expected_length = audio_chunk.shape[-1] * 3
            if audio_hr.shape[-1] > expected_length:
                audio_hr = audio_hr[..., :expected_length]
            elif audio_hr.shape[-1] < expected_length:
                audio_hr = torch.nn.functional.pad(
                    audio_hr, (0, expected_length - audio_hr.shape[-1])
                )

            return audio_hr

    def process_bytes(
        self,
        audio_bytes: bytes,
        input_rate: int = 16000,
        output_rate: int = 48000
    ) -> bytes:
        """
        Process raw audio bytes from 16kHz to 48kHz.

        Args:
            audio_bytes: Raw PCM audio bytes (int16, little-endian)
            input_rate: Input sample rate (must be 16000)
            output_rate: Output sample rate (must be 48000)

        Returns:
            Upsampled audio bytes (int16 PCM); empty input yields empty output.

        Raises:
            ValueError: If the rates are unsupported or the byte count is odd.
            RuntimeError: If the model has not been loaded yet.
        """
        if input_rate != 16000:
            raise ValueError(f"Input rate must be 16000, got {input_rate}")
        if output_rate != 48000:
            raise ValueError(f"Output rate must be 48000, got {output_rate}")
        # int16 PCM requires an even number of bytes; fail loudly instead of
        # letting np.frombuffer raise an opaque error.
        if len(audio_bytes) % 2:
            raise ValueError("audio_bytes length must be even (int16 PCM)")

        # Nothing to do for an empty chunk; avoid feeding a zero-length
        # tensor into the STFT pipeline.
        if not audio_bytes:
            return b""

        if not self.is_loaded:
            raise RuntimeError("SR model not loaded. Call load_model() first.")

        start_time = time.time()

        # Convert bytes to numpy array (int16 PCM) normalized to [-1, 1)
        audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0

        # Convert to torch tensor
        audio_tensor = torch.from_numpy(audio_np).to(self.device)

        # Process through model
        audio_hr = self.process_chunk(audio_tensor)

        # Convert back to int16 PCM bytes (clip to avoid wrap-around on overflow)
        audio_hr_np = audio_hr.squeeze().cpu().numpy()
        audio_hr_int16 = np.clip(audio_hr_np * 32768.0, -32768, 32767).astype(np.int16)
        audio_hr_bytes = audio_hr_int16.tobytes()

        # Track performance
        processing_time = time.time() - start_time
        self.total_chunks_processed += 1
        self.total_processing_time += processing_time

        # Log performance occasionally; lazy %-args keep the hot path cheap
        # when the log level filters this record out.
        if self.total_chunks_processed % 100 == 0:
            avg_time = (self.total_processing_time / self.total_chunks_processed) * 1000
            logger.info(
                "SR performance: %d chunks, avg %.2fms/chunk",
                self.total_chunks_processed, avg_time,
            )

        return audio_hr_bytes

    def get_stats(self) -> dict:
        """Get performance statistics.

        Returns:
            Dict with load state, device, FP16 flag, chunk count, and average
            processing time; total time is included once chunks were processed.
        """
        chunks = self.total_chunks_processed
        stats = {
            "loaded": self.is_loaded,
            "device": str(self.device) if self.device else None,
            "fp16": self.use_fp16,
            "chunks_processed": chunks,
            "avg_processing_time_ms": (
                (self.total_processing_time / chunks) * 1000 if chunks else 0.0
            ),
        }
        # Keep the original contract: the total only appears after work was done.
        if chunks:
            stats["total_processing_time_s"] = self.total_processing_time
        return stats

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None

        # Return cached GPU blocks to the allocator pool
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.is_loaded = False
        logger.info("SR model unloaded")


# Global instance getter
def get_sr_service() -> SuperResolutionService:
    """Return the process-wide singleton SuperResolutionService instance."""
    service = SuperResolutionService()
    return service