"""
Local entry point for Veena3 TTS - runs the same FastAPI app without Modal.

Replicates what Modal's @modal.enter + @modal.asgi_app does:
1. Sets up PYTHONPATH for vendored deps (sparktts, AP-BWE)
2. Downloads model weights from HuggingFace if not present
3. Initializes TTS runtime (vLLM engine, BiCodec decoder, pipelines)
4. Serves the FastAPI app via uvicorn

Usage:
    # Basic (auto-downloads model, uses defaults)
    python -m veena3modal.local_server

    # Custom model path (skip download)
    python -m veena3modal.local_server --model-path /path/to/spark_tts_4speaker

    # Custom GPU memory + port
    python -m veena3modal.local_server --gpu-memory 0.5 --port 8080

    # With super-resolution (48kHz output)
    python -m veena3modal.local_server --enable-sr --sr-path /path/to/ap_bwe/16kto48k

    # CPU-only (no GPU, for testing only - very slow)
    python -m veena3modal.local_server --device cpu

Environment Variables (override defaults):
    MODEL_PATH              - Path to Spark TTS model directory
    HF_TOKEN                - HuggingFace token for private models
    AP_BWE_CHECKPOINT_DIR   - Path to super-resolution checkpoints
    AUTH_BYPASS_MODE        - "true" to disable auth (default for local)
    GPU_MEMORY_UTILIZATION  - vLLM GPU memory fraction (default: 0.25)
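
Quick check (once the server is up; the generate payload below is
illustrative - see http://localhost:8000/docs for the actual schema):
    curl http://localhost:8000/v1/tts/health
    curl -X POST http://localhost:8000/v1/tts/generate \
        -H "Content-Type: application/json" \
        -d '{"text": "Hello from Veena3"}' -o out.wav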
"""

from __future__ import annotations

import argparse
import logging
import os
import sys
import time

# === Path Setup (MUST happen before any veena3modal imports) ===
# Vendored deps: sparktts and AP-BWE live in external/ relative to repo root
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
EXTERNAL_DIR = os.path.join(REPO_ROOT, "external")

# sparktts package path: external/sparktts (contains sparktts/ subpackage)
SPARKTTS_PATH = os.path.join(EXTERNAL_DIR, "sparktts")
# AP-BWE path: external/AP-BWE
AP_BWE_PATH = os.path.join(EXTERNAL_DIR, "AP-BWE")

for path in [SPARKTTS_PATH, AP_BWE_PATH, REPO_ROOT]:
    if path not in sys.path:
        sys.path.insert(0, path)

# Also expose these paths to subprocesses (e.g., ffmpeg wrappers) via PYTHONPATH.
# os.pathsep keeps the separator portable across platforms.
existing_pythonpath = os.environ.get("PYTHONPATH", "")
new_paths = os.pathsep.join([SPARKTTS_PATH, AP_BWE_PATH, REPO_ROOT])
if existing_pythonpath:
    os.environ["PYTHONPATH"] = f"{new_paths}{os.pathsep}{existing_pythonpath}"
else:
    os.environ["PYTHONPATH"] = new_paths

logger = logging.getLogger("veena3.local")

# Default local model directory (relative to repo)
DEFAULT_LOCAL_MODEL_DIR = os.path.join(REPO_ROOT, "models", "spark_tts_4speaker")

# Model sources (tried in order):
# 1. Local path (--model-path or MODEL_PATH env)
# 2. HuggingFace download (private repo: BayAreaBoys/spark_tts_4speaker)
# 3. Base Spark TTS model (public: SparkAudio/Spark-TTS-0.5B)
PRIVATE_HF_REPO = "BayAreaBoys/spark_tts_4speaker"
PUBLIC_FALLBACK_REPO = "SparkAudio/Spark-TTS-0.5B"


def download_model(
    hf_repo: str = PRIVATE_HF_REPO,
    local_dir: str = DEFAULT_LOCAL_MODEL_DIR,
    hf_token: str | None = None,
) -> str:
    """
    Download Spark TTS model from HuggingFace if not already present.

    Strategy:
    1. Try private repo (BayAreaBoys/spark_tts_4speaker) with HF token
    2. Fall back to public base model (SparkAudio/Spark-TTS-0.5B)
    3. Provide instructions for Modal volume download if both fail

    Returns the resolved local path.
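
    Example (sketch; the first call downloads per the strategy above,
    later calls return the cached path):
        path = download_model(hf_token=os.environ.get("HF_TOKEN"))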
    """
    # Check for any valid marker file (config.json at root, config.yaml, or LLM/config.json)
    marker_files = ["config.json", "config.yaml", "LLM/config.json"]
    if any(os.path.exists(os.path.join(local_dir, m)) for m in marker_files):
        logger.info(f"Model already present at {local_dir}")
        return local_dir

    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        logger.error(
            "huggingface_hub not installed. Install with: pip install huggingface-hub\n"
            "Or download model manually and pass --model-path"
        )
        sys.exit(1)

    os.makedirs(local_dir, exist_ok=True)

    # Try private repo first
    logger.info(f"Attempting download from {hf_repo}...")
    try:
        downloaded_path = snapshot_download(
            repo_id=hf_repo,
            local_dir=local_dir,
            token=hf_token,
            ignore_patterns=["*.bin", "training_args.bin", "optimizer.pt"],
        )
        logger.info(f"Model downloaded to {downloaded_path}")
        return downloaded_path
    except Exception as e:
        logger.warning(f"Private repo download failed: {e}")

    # Try public fallback
    if hf_repo != PUBLIC_FALLBACK_REPO:
        logger.info(f"Trying public base model: {PUBLIC_FALLBACK_REPO}...")
        try:
            downloaded_path = snapshot_download(
                repo_id=PUBLIC_FALLBACK_REPO,
                local_dir=local_dir,
                token=hf_token,
            )
            logger.info(f"Base model downloaded to {downloaded_path}")
            logger.warning(
                "Using public base SparkTTS model (no custom speakers). "
                "For full speaker support, download your fine-tuned model."
            )
            return downloaded_path
        except Exception as e2:
            logger.warning(f"Public model download also failed: {e2}")

    # All downloads failed - provide manual instructions
    logger.error(
        "\nModel download failed. Options:\n"
        "  1. Download from Modal volume:\n"
        "     modal volume get veena3-models spark_tts_4speaker models/spark_tts_4speaker\n\n"
        "  2. Download from HuggingFace (if you have access):\n"
        f"     huggingface-cli download {PRIVATE_HF_REPO} --local-dir {local_dir}\n\n"
        "  3. Use base SparkTTS model:\n"
        f"     huggingface-cli download {PUBLIC_FALLBACK_REPO} --local-dir {local_dir}\n\n"
        "  4. Specify a custom path:\n"
        "     python -m veena3modal.local_server --model-path /path/to/model"
    )
    sys.exit(1)


def resolve_model_paths(model_path: str) -> tuple[str, str]:
    """
    Resolve LLM and BiCodec paths from a Spark TTS model directory.

    Spark TTS model structure (from SparkAudio/Spark-TTS-0.5B):
      model_path/
        LLM/          <- Language model (vLLM loads this)
        BiCodec/      <- Audio tokenizer/decoder
        config.yaml   <- Top-level config

    Fine-tuned models may have LLM files at root (flat structure).
    This function auto-detects both layouts.

    Returns:
        (llm_path, bicodec_path) resolved paths
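
    Example (paths illustrative; both layouts return the model root as
    the BiCodec path):
        llm, bicodec = resolve_model_paths("models/spark_tts_4speaker")
        # Subdirectory layout -> llm == "models/spark_tts_4speaker/LLM"
        # Flat layout         -> llm == "models/spark_tts_4speaker"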
    """
    # Check for subdirectory layout (SparkAudio/Spark-TTS-0.5B style)
    llm_subdir = os.path.join(model_path, "LLM")

    if os.path.isdir(llm_subdir) and os.path.exists(os.path.join(llm_subdir, "config.json")):
        llm_path = llm_subdir
    else:
        # Flat layout: LLM files at root
        llm_path = model_path

    # BiCodec loading expects the model root (which contains the BiCodec/ subdir)
    bicodec_path = model_path

    logger.info(f"Resolved model paths:")
    logger.info(f"  LLM (vLLM):    {llm_path}")
    logger.info(f"  BiCodec:       {bicodec_path}")

    return llm_path, bicodec_path


def initialize_local_runtime(
    model_path: str,
    device: str = "cuda",
    gpu_memory_utilization: float = 0.85,
    enable_sr: bool = False,
    sr_checkpoint_dir: str | None = None,
    hf_token: str | None = None,
    num_engines: int = 1,
):
    """
    Initialize TTS runtime for local serving.

    This replaces Modal's @modal.enter lifecycle hook.
    Loads: vLLM engine, BiCodec decoder, prompt builder, pipelines.
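
    Example (CPU smoke test - expect very slow synthesis; model path
    illustrative):
        runtime = initialize_local_runtime(
            "models/spark_tts_4speaker", device="cpu"
        )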
    """
    from veena3modal.services.tts_runtime import initialize_runtime

    # Resolve LLM vs BiCodec paths from model directory
    llm_path, bicodec_path = resolve_model_paths(model_path)

    start = time.time()
    logger.info("=" * 60)
    logger.info("Initializing Veena3 TTS Runtime (local)")
    logger.info(f"  Model:       {model_path}")
    logger.info(f"  LLM path:    {llm_path}")
    logger.info(f"  BiCodec:     {bicodec_path}")
    logger.info(f"  Device:      {device}")
    logger.info(f"  GPU Memory:  {gpu_memory_utilization:.0%} (total)")
    logger.info(f"  Engines:     {num_engines}")
    logger.info(f"  SR:          {'enabled' if enable_sr else 'disabled'}")
    logger.info("=" * 60)

    runtime = initialize_runtime(
        model_path=llm_path,
        bicodec_path=bicodec_path,
        sr_checkpoint_dir=sr_checkpoint_dir if enable_sr else None,
        device=device,
        hf_token=hf_token,
        gpu_memory_utilization=gpu_memory_utilization,
        enable_sr=enable_sr,
        num_engines=num_engines,
    )

    elapsed = time.time() - start
    logger.info(f"Runtime initialized in {elapsed:.1f}s")
    logger.info(f"  Model version: {runtime.model_version}")
    logger.info(f"  Streaming: {'available' if runtime.streaming_pipeline else 'unavailable'}")
    logger.info(f"  Super-resolution: {'loaded' if runtime.sr_service and runtime.sr_service.is_loaded else 'disabled'}")

    return runtime


def create_local_app():
    """
    Create the FastAPI app for local serving.

    Same app factory as Modal deployment, but called from local context.
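
    Minimal test sketch (assumes the runtime was already initialized in
    this process and that httpx is installed for TestClient):
        from fastapi.testclient import TestClient
        client = TestClient(create_local_app())
        assert client.get("/v1/tts/health").status_code == 200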
    """
    from veena3modal.api.fastapi_app import create_app
    return create_app()


def parse_args():
    """Parse CLI arguments for local server configuration."""
    parser = argparse.ArgumentParser(
        description="Veena3 TTS Local Server - GPU-accelerated text-to-speech",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m veena3modal.local_server                          # Auto-download model, serve on :8000
  python -m veena3modal.local_server --port 8080              # Custom port
  python -m veena3modal.local_server --model-path ./my_model  # Use local model
  python -m veena3modal.local_server --gpu-memory 0.5         # Limit GPU memory
  python -m veena3modal.local_server --workers 2              # Multiple workers (careful: each loads model)
        """,
    )

    # Model configuration
    model_group = parser.add_argument_group("Model")
    model_group.add_argument(
        "--model-path",
        type=str,
        default=os.environ.get("MODEL_PATH", DEFAULT_LOCAL_MODEL_DIR),
        help=f"Path to Spark TTS model (default: {DEFAULT_LOCAL_MODEL_DIR})",
    )
    model_group.add_argument(
        "--hf-repo",
        type=str,
        default=PRIVATE_HF_REPO,
        help=f"HuggingFace repo to download from (default: {PRIVATE_HF_REPO})",
    )
    model_group.add_argument(
        "--hf-token",
        type=str,
        default=os.environ.get("HF_TOKEN"),
        help="HuggingFace token (default: from HF_TOKEN env var)",
    )
    model_group.add_argument(
        "--skip-download",
        action="store_true",
        help="Don't auto-download model (fail if not found)",
    )

    # GPU configuration
    gpu_group = parser.add_argument_group("GPU")
    gpu_group.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device for inference (default: cuda)",
    )
    gpu_group.add_argument(
        "--gpu-memory",
        type=float,
        default=float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.25")),
        help="vLLM GPU memory utilization fraction (default: 0.25)",
    )
    gpu_group.add_argument(
        "--num-engines",
        type=int,
        default=int(os.environ.get("NUM_ENGINES", "1")),
        help="Number of vLLM engine instances (default: 1, use 2-3 for multi-engine mode)",
    )

    # Super-resolution
    sr_group = parser.add_argument_group("Super-Resolution")
    sr_group.add_argument(
        "--enable-sr",
        action="store_true",
        default=False,
        help="Enable AP-BWE super-resolution (16kHz -> 48kHz)",
    )
    sr_group.add_argument(
        "--sr-path",
        type=str,
        default=os.environ.get("AP_BWE_CHECKPOINT_DIR"),
        help="Path to AP-BWE checkpoint directory",
    )

    # Server configuration
    server_group = parser.add_argument_group("Server")
    server_group.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Bind address (default: 0.0.0.0)",
    )
    server_group.add_argument(
        "--port",
        type=int,
        default=int(os.environ.get("PORT", "8000")),
        help="Port to serve on (default: 8000)",
    )
    server_group.add_argument(
        "--workers",
        type=int,
        default=1,
        help="Number of uvicorn workers (default: 1, each loads full model)",
    )
    server_group.add_argument(
        "--reload",
        action="store_true",
        default=False,
        help="Enable auto-reload for development (incompatible with --workers > 1)",
    )

    # Auth
    auth_group = parser.add_argument_group("Auth")
    auth_group.add_argument(
        "--auth",
        action="store_true",
        default=False,
        help="Enable API key authentication (default: disabled for local)",
    )

    # Logging
    log_group = parser.add_argument_group("Logging")
    log_group.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Log level (default: INFO)",
    )

    return parser.parse_args()


def main():
    """
    Main entry point for local TTS server.

    Lifecycle:
    1. Parse args + configure logging
    2. Download model if needed
    3. Initialize TTS runtime (vLLM engine, decoders, pipelines)
    4. Create FastAPI app
    5. Run uvicorn
    """
    args = parse_args()

    # Configure logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Auth bypass for local development (unless --auth flag is set)
    if not args.auth:
        os.environ["AUTH_BYPASS_MODE"] = "true"
        logger.info("Auth bypass enabled (local mode). Use --auth to require API keys.")
    else:
        # --auth explicitly requests authentication, so override any
        # AUTH_BYPASS_MODE=true lingering in the environment.
        os.environ["AUTH_BYPASS_MODE"] = "false"

    # Resolve model path
    # Spark TTS root has config.yaml; LLM subdir has config.json; fine-tuned may have config.json at root
    model_path = args.model_path
    model_exists = any(
        os.path.exists(os.path.join(model_path, marker))
        for marker in ("config.json", "config.yaml", "LLM/config.json")
    )
    if not model_exists:
        if args.skip_download:
            logger.error(f"Model not found at {model_path} and --skip-download is set")
            sys.exit(1)
        model_path = download_model(
            hf_repo=args.hf_repo,
            local_dir=model_path,
            hf_token=args.hf_token,
        )

    # Initialize TTS runtime (this is the expensive step - loads model onto GPU)
    initialize_local_runtime(
        model_path=model_path,
        device=args.device,
        gpu_memory_utilization=args.gpu_memory,
        enable_sr=args.enable_sr,
        sr_checkpoint_dir=args.sr_path,
        hf_token=args.hf_token,
        num_engines=args.num_engines,
    )

    # Run uvicorn
    # NOTE: We pass the app as a factory string so uvicorn can import it by name.
    # With --reload or --workers > 1, uvicorn spawns fresh subprocesses that
    # re-import this module; the runtime initialized above lives only in this
    # process, so prefer the defaults (one worker, no reload) unless the app
    # factory re-initializes the runtime itself.
    import uvicorn

    logger.info("")
    logger.info("=" * 60)
    logger.info(f"Veena3 TTS Local Server starting")
    logger.info(f"  URL: http://{args.host}:{args.port}")
    logger.info(f"  Health: http://{args.host}:{args.port}/v1/tts/health")
    logger.info(f"  Generate: POST http://{args.host}:{args.port}/v1/tts/generate")
    logger.info(f"  Metrics: http://{args.host}:{args.port}/v1/tts/metrics")
    logger.info(f"  Docs: http://{args.host}:{args.port}/docs")
    logger.info("=" * 60)

    uvicorn.run(
        "veena3modal.local_server:create_local_app",
        factory=True,
        host=args.host,
        port=args.port,
        workers=args.workers,
        reload=args.reload,
        log_level=args.log_level.lower(),
        # Performance: keep-alive for connection reuse under load
        timeout_keep_alive=30,
    )


if __name__ == "__main__":
    main()
