"""
Modal entrypoint for Veena3 TTS service.

Deploys a GPU-backed ASGI app with:
- True streaming TTS
- Autoscaling (0 to N containers)
- Per-container concurrency (vLLM batching)
- Memory snapshots for fast cold starts

Usage:
    # Deploy to Modal
    modal deploy veena3modal/app.py
    
    # Serve locally (for testing)
    modal serve veena3modal/app.py
    
    # Run a single function
    modal run veena3modal/app.py::tts_api
"""

from __future__ import annotations

import os
import modal

# Repo root = two directories above this file. Used below to locate the
# local source trees (external/sparktts, external/AP-BWE) added to the image.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# === Image Definition ===
# Full dependencies for TTS inference

# NOTE: step order in this builder chain matters — build steps (pip_install,
# run_commands, copy=True dirs) must precede the non-copy mount at the end.
image = (
    modal.Image.debian_slim(python_version="3.10")
    .apt_install(
        "ffmpeg",  # Audio encoding (opus, mp3, flac, mulaw)
        "git",     # For pip git installs
    )
    .pip_install(
        # Core ML/AI
        "torch>=2.0.0",
        "torchaudio>=2.0.0",
        "transformers>=4.35.0",
        "vllm>=0.4.0",

        # FastAPI
        "fastapi[standard]>=0.100.0",
        "uvicorn[standard]>=0.23.0",

        # Audio processing
        "numpy>=1.24.0",
        "scipy>=1.10.0",
        "librosa>=0.10.0",
        "soundfile>=0.12.0",

        # Utils
        "pydantic>=2.0.0",
        "prometheus_client>=0.17.0",
        "supabase>=2.0.0",

        # vLLM dependencies
        "einops>=0.6.0",
        "einx>=0.2.0",

        # SparkTTS dependencies (from external/sparktts/requirements.txt)
        "omegaconf>=2.3.0",
        "safetensors>=0.5.0",
        "soxr>=0.5.0",

        # AP-BWE (super-resolution) dependencies
        "matplotlib>=3.8.0",
        "natsort>=8.0.0",
        "joblib>=1.0.0",
    )
    .run_commands(
        # Best-effort NVML bindings; `|| true` keeps the image build green
        # even if the wheel is unavailable.
        "pip install --no-cache-dir nvidia-ml-py3 || true"
    )
    # Environment settings baked into the image
    .env({
        "AUTH_BYPASS_MODE": "true",  # Disable auth for testing
        # sparktts is at /root/external/sparktts, AP-BWE at /root/external/AP-BWE
        "PYTHONPATH": "/root/external/sparktts:/root/external/AP-BWE:/root",
        # Presumably bumping this string invalidates Modal's layer cache to
        # force a rebuild — confirm against Modal image caching behavior.
        "_IMAGE_BUILD_VERSION": "2025-12-25-sr-chunking-support",  # Add SR support to generate_speech_chunked
    })
    # Add sparktts package - create proper structure by copying files.
    # The original repo uses symlinks that don't survive Modal mounts.
    .run_commands(
        # Create sparktts package directory skeleton before copying into it
        "mkdir -p /root/external/sparktts/sparktts/models",
        "mkdir -p /root/external/sparktts/sparktts/utils",
    )
    # copy=True materializes files at build time (a regular mount would be
    # added after build steps, too late for the layout created above).
    .add_local_dir(
        local_path=os.path.join(REPO_ROOT, "external", "sparktts", "sparktts"),
        remote_path="/root/external/sparktts/sparktts",
        copy=True,
    )
    .add_local_dir(
        local_path=os.path.join(REPO_ROOT, "external", "sparktts", "models"),
        remote_path="/root/external/sparktts/sparktts/models",  # Put models inside sparktts package
        copy=True,
    )
    .add_local_dir(
        local_path=os.path.join(REPO_ROOT, "external", "sparktts", "utils"),
        remote_path="/root/external/sparktts/sparktts/utils",  # Put utils inside sparktts package
        copy=True,
    )
    # Add AP-BWE for super-resolution (16kHz -> 48kHz)
    .add_local_dir(
        local_path=os.path.join(REPO_ROOT, "external", "AP-BWE"),
        remote_path="/root/external/AP-BWE",
        copy=True,
    )
    # Non-copy local mounts (MUST BE LAST - no build steps after this!)
    .add_local_python_source("veena3modal")
)

# === Volume for Model Weights ===
# Models should be uploaded to this volume before deployment

# Deliberate best-effort lookup: the broad except keeps module import working
# even when Volume resolution fails (e.g. offline tooling); containers that
# actually need weights will then start without the /models mount.
try:
    model_volume = modal.Volume.from_name("veena3-models", create_if_missing=True)
except Exception:
    model_volume = None  # Will fail at runtime if volume not available

# === Secrets (Optional) ===
# Secrets are optional - the app will run in bypass mode without them
# Create secrets in Modal dashboard: https://modal.com/secrets
# Or use modal.Secret.from_dotenv() for local testing

# === App Definition ===

# Single Modal App; all functions/classes below attach to it and share `image`.
app = modal.App(
    name="veena3-tts",
    image=image,
)


@app.cls(
    # === GPU Selection ===
    # L40S: Best cost/performance ratio for TTS inference
    # - 48GB VRAM, $1.50/hr (vs A100-80GB at $4.50/hr)
    # - Handles ~12 req/s with <2s p95 latency per container
    # - Memory-bound workload, so faster GPUs have diminishing returns
    gpu="L40S",

    # === CPU Allocation (OPTIMIZATION Dec 2025) ===
    # Explicitly reserve 4 vCPUs for the Python driver process
    # Prevents CPU throttling during vLLM scheduler + streaming loop
    # Modal default is only 0.125 cores which can bottleneck streaming
    cpu=4.0,

    volumes={"/models": model_volume} if model_volume else {},
    secrets=[modal.Secret.from_name("veena3-secrets")],  # Supabase credentials

    # === Autoscaling Configuration ===
    # See .cursor/scaling.md for detailed analysis
    #
    # LOAD TEST RESULTS (Dec 2025):
    # - Cold boot: ~49s (memory snapshot captures imports, not GPU model)
    # - MAX SUSTAINABLE: 280 req/s @ 250 concurrent users (P95=1417ms, 100% success)
    # - Breaking point: 300+ users (429 errors start)
    # - Per container: ~14 req/s (280/20 containers)
    #
    min_containers=1,        # Keep 1 warm 24/7 for instant response
    max_containers=50,       # ~280 req/s capacity at P95 < 3s
    buffer_containers=1,     # REQUIRED: 49s cold boot too slow for burst
    scaledown_window=300,    # 5 min idle before scaledown (raised from 120)

    # === Timeouts ===
    timeout=120,             # 2 min max per request (reduced from 10 min)
    startup_timeout=600,     # 10 min for model loading (reduced from 20 min)

    # === Memory Optimization ===
    # Memory snapshot for faster cold boots
    # Note: GPU snapshot is experimental (alpha) - disabled for stability
    enable_memory_snapshot=True,
)
# === Per-Container Concurrency ===
# vLLM continuous batching handles concurrent requests efficiently
# Increased from 12 to 20 to maximize GPU utilization while keeping P95 < 3s
@modal.concurrent(max_inputs=20, target_inputs=15)
class TTSService:
    """
    TTS service class with model lifecycle management.

    Uses @modal.enter for model loading (once per container).
    Memory snapshot captures imports and non-GPU state for faster restarts.

    Attributes:
        runtime: The initialized TTS runtime, or None when the model is
            missing/failed to load (the API then serves 503s).
    """

    @modal.enter()
    def load_model(self):
        """
        Load the TTS model when the container starts.

        Called once per container lifecycle. Memory snapshot captures
        import state but not GPU memory.

        Never raises: on a missing model path or a loading failure it
        leaves ``self.runtime`` as None so the ASGI app can keep serving
        503s until the model becomes available.
        """
        import os
        import logging

        logging.basicConfig(level=logging.INFO)
        logger = logging.getLogger(__name__)

        logger.info("🚀 Starting TTS model loading...")

        # Initialize unconditionally so downstream code can test
        # `self.runtime is None` instead of hitting AttributeError when
        # loading is skipped or fails.
        self.runtime = None

        # Set model paths from Volume mount or env vars
        model_path = os.environ.get(
            "MODEL_PATH",
            "/models/spark_tts_4speaker"
        )
        sr_path = os.environ.get(
            "AP_BWE_CHECKPOINT_DIR",
            "/models/ap_bwe/16kto48k"
        )

        # Missing weights are non-fatal: serve 503s until they appear.
        if not os.path.exists(model_path):
            logger.warning(f"Model path not found: {model_path}")
            logger.warning("TTS will return 503 until model is available")
            return

        try:
            from veena3modal.services.tts_runtime import initialize_runtime

            # Super-resolution (AP-BWE) is optional; enable only when its
            # checkpoint directory is present on the volume.
            sr_available = os.path.exists(sr_path)

            # Initialize with GPU memory optimization
            self.runtime = initialize_runtime(
                model_path=model_path,
                sr_checkpoint_dir=sr_path if sr_available else None,
                device="cuda",
                gpu_memory_utilization=0.85,
                enable_sr=sr_available,
            )

            logger.info(f"✅ TTS model loaded: {self.runtime.model_version}")

        except Exception:
            # Deliberately swallow so the container stays up and serves 503s;
            # logger.exception records the full traceback via logging instead
            # of an unstructured traceback.print_exc() to stderr.
            logger.exception("❌ Model loading failed")

    @modal.asgi_app()
    def serve(self):
        """
        Return the FastAPI ASGI app for serving requests.
        """
        from veena3modal.api.fastapi_app import create_app
        return create_app()


# === Standalone Function (Alternative Deployment) ===
# Use this if you don't need @modal.enter lifecycle
# Note: Prefer TTSService class for production (better model lifecycle management)

@app.function(
    gpu="L40S",  # Best cost/performance for TTS
    cpu=4.0,     # Explicit CPU allocation (OPTIMIZATION Dec 2025)
    volumes={"/models": model_volume} if model_volume else {},
    secrets=[modal.Secret.from_name("veena3-secrets")],
    min_containers=0,        # Standalone variant may scale fully to zero
    max_containers=5,        # Tighter cap than the class-based service
    buffer_containers=1,
    scaledown_window=180,    # 3 min idle before scaledown
    timeout=120,             # 2 min per request
    startup_timeout=600,     # Allow up to 10 min for model load
)
@modal.asgi_app()
def tts_api():
    """
    Standalone ASGI app function (alternative to TTSService).

    The model is loaded lazily on the first request, so TTFB is higher;
    prefer the TTSService class (with @modal.enter) in production.
    """
    from veena3modal.api.fastapi_app import create_app

    asgi_app = create_app()
    return asgi_app


# === Health Check Function ===

@app.function(
    # CPU-only: a liveness probe needs no GPU.
)
def health_check():
    """
    Lightweight liveness probe, invokable from the Modal dashboard/CLI.

    Returns:
        dict: static status/service/version metadata.
    """
    status_payload = {
        "status": "ok",
        "service": "veena3-tts",
        "version": "0.1.0",
    }
    return status_payload


# === Debug Function ===

@app.function(
    gpu="L40S",
    volumes={"/models": model_volume} if model_volume else {},
    timeout=300,
)
def debug_imports():
    """
    Inspect import paths and on-container file layout for troubleshooting.

    Run with: modal run veena3modal/app.py::debug_imports

    Returns:
        dict with PYTHONPATH, a sys.path prefix, per-directory listings,
        and per-module import results ("OK" or "FAILED: ...").
    """
    import os
    import sys

    report = {
        "pythonpath": os.environ.get("PYTHONPATH", "NOT_SET"),
        "sys_path": sys.path[:10],  # First 10 entries
        "files": {},
        "import_tests": {},
    }

    # Directories whose presence/contents we want to sanity-check.
    probe_dirs = (
        "/root",
        "/root/external",
        "/root/external/sparktts",
        "/root/external/sparktts/sparktts",
        "/root/external/sparktts/sparktts/models",
        "/root/veena3modal",
        "/models",
    )
    for directory in probe_dirs:
        if not os.path.exists(directory):
            report["files"][directory] = "DOES_NOT_EXIST"
            continue
        try:
            report["files"][directory] = os.listdir(directory)[:20]  # First 20 files
        except Exception as e:
            report["files"][directory] = f"ERROR: {e}"

    # Modules that must be importable for the service to run.
    probe_modules = (
        "sparktts",
        "sparktts.models",
        "sparktts.models.audio_tokenizer",
        "veena3modal",
        "veena3modal.core",
        "veena3modal.processing",
        "veena3modal.audio",
    )
    for module_name in probe_modules:
        try:
            __import__(module_name)
        except Exception as e:
            report["import_tests"][module_name] = f"FAILED: {e}"
        else:
            report["import_tests"][module_name] = "OK"

    return report


# === Local Testing Entry Point ===

@app.local_entrypoint()
def main():
    """
    Local entrypoint: print a banner, run debug_imports remotely, dump results.

    Usage: modal run veena3modal/app.py
    """
    import json

    banner = (
        "Veena3 TTS Modal App",
        "=" * 40,
        "\n🔍 Running debug_imports on GPU container...",
    )
    for line in banner:
        print(line)

    debug_report = debug_imports.remote()
    print("\n📊 Debug Results:")
    print(json.dumps(debug_report, indent=2))
