#!/usr/bin/env bash
# ============================================================================
# Veena3 TTS Local Setup
#
# Sets up everything needed to run the TTS server locally:
# 1. Creates/activates Python venv
# 2. Installs dependencies
# 3. Downloads model weights from HuggingFace
# 4. Validates GPU + CUDA availability
# 5. Runs a quick smoke test
#
# Usage:
#   bash scripts/setup_local.sh              # Full setup
#   bash scripts/setup_local.sh --skip-model # Skip model download
#   bash scripts/setup_local.sh --gpu-check  # Only check GPU
# ============================================================================

set -euo pipefail

# Resolve all paths relative to the repo root so the script works from any CWD.
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
readonly REPO_ROOT
readonly VENV_DIR="${REPO_ROOT}/venv"
readonly MODEL_DIR="${REPO_ROOT}/models/spark_tts_4speaker"
readonly HF_REPO="BayAreaBoys/spark_tts_4speaker"
# Minimum Python version expected for the venv interpreter.
readonly PYTHON_VERSION_REQUIRED="3.10"

# ANSI color codes for terminal output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # reset / no color

# Shared emitter: $1 is the colored tag (including its alignment padding),
# the remaining args are the message.
_log_line() {
    local prefix=$1
    shift
    echo -e "${prefix} $*"
}

log_info()  { _log_line "${BLUE}[INFO]${NC} "   "$@"; }
log_ok()    { _log_line "${GREEN}[OK]${NC}   "  "$@"; }
log_warn()  { _log_line "${YELLOW}[WARN]${NC} " "$@"; }
log_err()   { _log_line "${RED}[ERROR]${NC}"    "$@"; }

# Parse command-line flags (see the usage block at the top of the file).
SKIP_MODEL=false
GPU_CHECK_ONLY=false
for arg in "$@"; do
    case "$arg" in
        --skip-model)  SKIP_MODEL=true ;;
        --gpu-check)   GPU_CHECK_ONLY=true ;;
        # Previously unknown flags were silently ignored; warn so typos
        # (e.g. --skip-models) don't go unnoticed.
        *)             log_warn "Unknown flag ignored: ${arg}" ;;
    esac
done

# ============================================================================
# 1. GPU Check
# ============================================================================
echo ""
echo "============================================================"
echo "  Veena3 TTS Local Setup"
echo "============================================================"

log_info "Checking GPU..."
if command -v nvidia-smi &>/dev/null; then
    GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null | head -1)
    GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader 2>/dev/null | head -1)
    GPU_DRIVER=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader 2>/dev/null | head -1)
    # Extract "CUDA Version: X.Y" by pattern, not column position. The old
    # `grep "CUDA Version" | awk '{print $9}'` depended on a fixed field index
    # and, under `set -eo pipefail`, a non-matching grep aborted the whole
    # script. `sed -n` exits 0 even when nothing matches.
    CUDA_VER=$(nvidia-smi 2>/dev/null | sed -n 's/.*CUDA Version: *\([0-9.]*\).*/\1/p' | head -1)
    log_ok "GPU: ${GPU_NAME} (${GPU_MEM})"
    log_ok "Driver: ${GPU_DRIVER}, CUDA: ${CUDA_VER:-unknown}"
else
    log_warn "nvidia-smi not found. GPU inference will not work."
    log_warn "CPU-only mode available with --device cpu (very slow)"
fi

# --gpu-check: report GPU status and stop before any environment changes.
if [ "$GPU_CHECK_ONLY" = true ]; then
    exit 0
fi

# ============================================================================
# 2. Python Virtual Environment
# ============================================================================
log_info "Setting up Python virtual environment..."

# Verify the interpreter meets the minimum version before building the venv;
# a too-old venv would only fail later, mid-way through dependency install.
PYTHON_VERSION_ACTUAL=$(python3 -c 'import sys; print("%d.%d" % sys.version_info[:2])')
if [ "$(printf '%s\n' "$PYTHON_VERSION_REQUIRED" "$PYTHON_VERSION_ACTUAL" | sort -V | head -1)" != "$PYTHON_VERSION_REQUIRED" ]; then
    log_err "Python ${PYTHON_VERSION_REQUIRED}+ required, found ${PYTHON_VERSION_ACTUAL}"
    exit 1
fi

if [ ! -d "$VENV_DIR" ]; then
    log_info "Creating venv at ${VENV_DIR}..."
    python3 -m venv "$VENV_DIR"
    log_ok "Virtual environment created"
else
    log_ok "Virtual environment already exists at ${VENV_DIR}"
fi

# Activate venv
# shellcheck disable=SC1091 — activate script is generated at runtime
source "${VENV_DIR}/bin/activate"
log_ok "Activated venv: $(command -v python)"
log_ok "Python version: $(python --version)"

# Upgrade pip
pip install --quiet --upgrade pip wheel setuptools

# ============================================================================
# 3. Install Dependencies
# ============================================================================
log_info "Installing dependencies..."

# Core requirements
pip install --quiet -r "${REPO_ROOT}/requirements.txt"

# FastAPI + uvicorn (not in requirements.txt since Modal bundles them)
server_deps=(
    "fastapi[standard]>=0.100.0"
    "uvicorn[standard]>=0.23.0"
    "prometheus_client>=0.17.0"
    "supabase>=2.0.0"
)
pip install --quiet "${server_deps[@]}"

# HuggingFace Hub for model download
pip install --quiet "huggingface-hub[cli]>=0.24.0"

# SparkTTS extra deps
sparktts_deps=(
    "omegaconf>=2.3.0"
    "safetensors>=0.5.0"
    "soxr>=0.5.0"
)
pip install --quiet "${sparktts_deps[@]}"

# AP-BWE deps
apbwe_deps=(
    "matplotlib>=3.8.0"
    "natsort>=8.0.0"
    "joblib>=1.0.0"
)
pip install --quiet "${apbwe_deps[@]}"

# Einops (vLLM + model deps)
einops_deps=(
    "einops>=0.6.0"
    "einx>=0.2.0"
)
pip install --quiet "${einops_deps[@]}"

log_ok "All dependencies installed"

# ============================================================================
# 4. Validate Key Imports
# ============================================================================
log_info "Validating imports..."

python -c "
import torch
print(f'  torch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'  GPU: {torch.cuda.get_device_name(0)}')
    print(f'  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
" || { log_err "PyTorch import failed"; exit 1; }

# Import a module, print its version, and abort with a labeled error on failure.
require_module() {
    local module=$1 label=$2
    if ! python -c "import ${module}; print(f'  ${module}: {${module}.__version__}')"; then
        log_err "${label} import failed"
        exit 1
    fi
}

require_module vllm "vLLM"
require_module fastapi "FastAPI"

# sparktts is vendored under external/; a failure here is non-fatal because
# it may simply need the model files downloaded in the next step.
python -c "
import sys
sys.path.insert(0, '${REPO_ROOT}/external/sparktts')
import sparktts
print(f'  sparktts: available')
" || { log_warn "sparktts import failed (may need model files)"; }

log_ok "Core imports validated"

# ============================================================================
# 5. Download Model
# ============================================================================
if [ "$SKIP_MODEL" = false ]; then
    log_info "Checking model at ${MODEL_DIR}..."

    # config.json is used as the sentinel for a complete download.
    if [ -f "${MODEL_DIR}/config.json" ]; then
        log_ok "Model already present"
    else
        log_info "Downloading model from ${HF_REPO}..."
        log_info "This may take a few minutes on first run..."

        # Load HF token (and other vars) from .env if present. `set -a`
        # auto-exports every assignment made while sourcing; the previous
        # `export $(grep -v '^#' .env | xargs)` word-split values, breaking
        # on spaces or quotes.
        if [ -f "${REPO_ROOT}/.env" ]; then
            set -a
            # shellcheck disable=SC1091
            source "${REPO_ROOT}/.env"
            set +a
        fi

        mkdir -p "${MODEL_DIR}"

        # Pass repo/path/token via the environment instead of interpolating
        # them into the Python source: avoids quoting bugs and keeps the
        # token out of `python -c` argv, which is visible in `ps`.
        HF_REPO="${HF_REPO}" MODEL_DIR="${MODEL_DIR}" HF_TOKEN="${HF_TOKEN:-}" python -c "
import os
from huggingface_hub import snapshot_download
snapshot_download(
    repo_id=os.environ['HF_REPO'],
    local_dir=os.environ['MODEL_DIR'],
    token=os.environ.get('HF_TOKEN') or None,
)
print('Model download complete')
"
        if [ -f "${MODEL_DIR}/config.json" ]; then
            log_ok "Model downloaded successfully"
        else
            log_err "Model download may have failed - config.json not found"
            exit 1
        fi
    fi
else
    log_warn "Skipping model download (--skip-model)"
fi

# ============================================================================
# 6. Check ffmpeg (needed for audio encoding: opus, mp3, flac, mulaw)
# ============================================================================
log_info "Checking ffmpeg..."
if command -v ffmpeg &>/dev/null; then
    # Third whitespace-separated token of the first output line is the version.
    FFMPEG_VER=$(ffmpeg -version 2>&1 | awk 'NR == 1 {print $3}')
    log_ok "ffmpeg: ${FFMPEG_VER}"
else
    log_warn "ffmpeg not found. Install with: sudo apt install ffmpeg"
    log_warn "Audio format encoding (opus, mp3, flac, mulaw) will not work"
fi

# ============================================================================
# 7. Summary
# ============================================================================
# Quoted heredoc: no expansion, so backslashes and quotes print literally.
cat <<'SUMMARY'

============================================================
  Setup Complete!
============================================================

  To activate the environment:
    source venv/bin/activate

  To start the local TTS server:
    python -m veena3modal.local_server

  To start with custom settings:
    python -m veena3modal.local_server --port 8080 --gpu-memory 0.5

  Quick test after server starts:
    curl http://localhost:8000/v1/tts/health

    curl -X POST http://localhost:8000/v1/tts/generate \
      -H 'Content-Type: application/json' \
      -d '{"text": "Hello, this is a test.", "speaker": "Mitra"}' \
      --output test.wav

============================================================
SUMMARY