#!/usr/bin/env python3
"""
Deploy Maya3 Pipeline to all configured clusters.

Usage:
    python deploy_all.py --deploy          # Deploy code to all clusters
    python deploy_all.py --start           # Start workers on all clusters
    python deploy_all.py --stop            # Stop workers on all clusters
    python deploy_all.py --status          # Check status of all clusters
    python deploy_all.py --logs <cluster>  # Tail logs from a cluster
"""

import os
import sys
import json
import argparse
import subprocess
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

# Directory containing this script; the cluster config JSON is expected to
# live alongside it (see load_config()).
SCRIPT_DIR = Path(__file__).parent
CONFIG_FILE = SCRIPT_DIR / "cluster_config.json"


def load_config() -> dict:
    """Read and parse the cluster configuration JSON.

    Exits the process with status 1 (after printing guidance) when the
    config file does not exist.
    """
    if CONFIG_FILE.exists():
        return json.loads(CONFIG_FILE.read_text())

    print(f"Error: Config file not found: {CONFIG_FILE}")
    print("Please copy cluster_config.json and fill in your cluster details")
    sys.exit(1)


def run_ssh_command(cluster: dict, command: str, timeout: int = 300) -> tuple:
    """Run a shell command on a cluster over SSH.

    Args:
        cluster: Config entry with 'name', 'host', 'user', and 'ssh_key' keys.
        command: Shell command to execute on the remote host.
        timeout: Seconds to wait for the SSH call before giving up.

    Returns:
        Tuple of (cluster_name, success, stdout, stderr). `success` is True
        only when ssh exits with status 0.
    """
    # Pass argv as a list (shell=False) so the remote command is delivered
    # verbatim as a single ssh argument. The previous string-built
    # "ssh ... '<command>'" form broke on any command containing a single
    # quote and allowed local shell injection via config values.
    ssh_argv = [
        "ssh",
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-i", cluster['ssh_key'],
        f"{cluster['user']}@{cluster['host']}",
        command,
    ]

    try:
        result = subprocess.run(
            ssh_argv,
            capture_output=True,
            text=True,
            timeout=timeout
        )
        return (cluster['name'], result.returncode == 0, result.stdout, result.stderr)
    except subprocess.TimeoutExpired:
        return (cluster['name'], False, "", "Timeout")
    except Exception as e:
        # Best-effort: callers consume (name, success, ...) tuples, so report
        # unexpected failures (e.g. ssh binary missing) the same way.
        return (cluster['name'], False, "", str(e))


def deploy_cluster(cluster: dict) -> tuple:
    """Deploy code to a single cluster by invoking deploy_cluster.sh.

    Args:
        cluster: Config entry with 'name', 'host', 'user', and 'ssh_key' keys.

    Returns:
        Tuple of (cluster_name, success, stdout, stderr).
    """
    script = SCRIPT_DIR / "deploy_cluster.sh"
    # Argument list instead of a shell-interpolated string: key paths or
    # usernames containing spaces/metacharacters no longer break the call.
    cmd = ["bash", str(script), cluster['host'], cluster['ssh_key'], cluster['user']]

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=600  # deploys move code around and can legitimately be slow
        )
        return (cluster['name'], result.returncode == 0, result.stdout, result.stderr)
    except subprocess.TimeoutExpired:
        # Report timeouts the same way run_ssh_command does.
        return (cluster['name'], False, "", "Timeout")
    except Exception as e:
        return (cluster['name'], False, "", str(e))


def start_workers(cluster: dict, config: dict) -> tuple:
    """Start the processing workers on a single cluster over SSH.

    Kills any existing massive_process.py, waits briefly, then relaunches it
    detached under nohup so it survives the SSH session ending. Worker tuning
    knobs come from config['processing'], with defaults applied here.

    Args:
        cluster: Config entry with SSH connection details.
        config: Full deployment config; only the 'processing' section is read.

    Returns:
        Tuple of (cluster_name, success, stdout, stderr) from the SSH call;
        on success stdout carries the background PID echoed by `echo $!`.
    """
    proc = config.get('processing', {})
    # NOTE: backslash-newline inside this (non-raw) f-string is a Python
    # line-continuation escape, so those lines reach the remote shell already
    # joined. Keep the string byte-exact — it IS the remote command.
    cmd = f"""
cd /workspace/maya3_data && source .env && \
pkill -f 'massive_process.py' || true && sleep 2 && \
nohup .venv/bin/python massive_process.py \
    --supabase-queue \
    --r2-upload \
    --background-export \
    --export-workers {proc.get('export_workers', 4)} \
    --lease-duration {proc.get('lease_duration_sec', 900)} \
    --poll-interval {proc.get('poll_interval_sec', 10)} \
    > /tmp/maya3_workers.log 2>&1 &
echo $!
"""
    return run_ssh_command(cluster, cmd)


def stop_workers(cluster: dict) -> tuple:
    """Kill any running worker processes on a single cluster.

    Returns the (name, success, stdout, stderr) tuple from the SSH call.
    """
    # pkill exits non-zero when nothing matched; the || branch keeps the
    # remote command's exit status clean either way.
    return run_ssh_command(
        cluster,
        "pkill -f 'massive_process.py' && echo 'Stopped' || echo 'No workers running'",
    )


def get_status(cluster: dict) -> tuple:
    """Collect a quick health snapshot (GPUs, workers, recent log) from a cluster.

    Args:
        cluster: Config entry with SSH connection details.

    Returns:
        Tuple of (cluster_name, success, stdout, stderr); stdout carries the
        formatted report the remote shell script below prints.
    """
    # Each section falls back to a human-readable message rather than failing,
    # so the report still renders on hosts with no GPU, no workers, or no log.
    cmd = """
echo "=== GPU Status ==="
nvidia-smi --query-gpu=index,utilization.gpu,memory.used,memory.total --format=csv,noheader 2>/dev/null || echo "GPU info unavailable"

echo ""
echo "=== Worker Processes ==="
pgrep -af "massive_process.py" || echo "No workers running"

echo ""
echo "=== Recent Log ==="
tail -5 /tmp/maya3_workers.log 2>/dev/null || echo "No log file"
"""
    # Short timeout: status checks should be snappy, unlike deploys.
    return run_ssh_command(cluster, cmd, timeout=30)


def get_logs(cluster: dict, lines: int = 50) -> tuple:
    """Fetch the last *lines* lines of the worker log from a cluster.

    Returns the (name, success, stdout, stderr) tuple from the SSH call.
    """
    tail_cmd = f"tail -{lines} /tmp/maya3_workers.log"
    return run_ssh_command(cluster, tail_cmd, timeout=30)


def parallel_execute(clusters: list, func, config: dict = None):
    """Run *func* against every cluster concurrently and print each result.

    Args:
        clusters: List of cluster config dicts.
        func: Callable taking (cluster) — or (cluster, config) when config
            is supplied — and returning a (name, success, stdout, stderr) tuple.
        config: Optional dict forwarded to func as a second argument.

    Returns:
        List of (name, success, stdout, stderr) tuples, in completion order.
    """
    # ThreadPoolExecutor raises ValueError for max_workers=0, which the
    # previous version hit whenever the cluster list was empty.
    if not clusters:
        return []

    results = []

    with ThreadPoolExecutor(max_workers=len(clusters)) as executor:
        # `is not None` rather than truthiness, so an explicitly passed
        # empty config dict is still forwarded to func.
        if config is not None:
            futures = {executor.submit(func, c, config): c for c in clusters}
        else:
            futures = {executor.submit(func, c): c for c in clusters}

        for future in as_completed(futures):
            result = future.result()
            results.append(result)

            name, success, stdout, stderr = result
            status = "✓" if success else "✗"
            print(f"\n[{status}] {name}")
            if stdout:
                print(stdout[:500])  # cap output so one cluster can't flood the console
            if stderr and not success:
                print(f"Error: {stderr[:200]}")

    return results


def main():
    """CLI entry point: parse arguments, pick target clusters, dispatch."""
    parser = argparse.ArgumentParser(description="Maya3 Multi-Cluster Deployment")
    parser.add_argument('--deploy', action='store_true', help='Deploy code to all clusters')
    parser.add_argument('--start', action='store_true', help='Start workers on all clusters')
    parser.add_argument('--stop', action='store_true', help='Stop workers on all clusters')
    parser.add_argument('--status', action='store_true', help='Check status of all clusters')
    parser.add_argument('--logs', type=str, metavar='CLUSTER', help='Tail logs from a specific cluster')
    parser.add_argument('--cluster', type=str, help='Target specific cluster by name')

    args = parser.parse_args()

    # With no action requested there is nothing to do: show usage and bail.
    requested = [args.deploy, args.start, args.stop, args.status, args.logs]
    if not any(requested):
        parser.print_help()
        sys.exit(1)

    config = load_config()
    targets = config['clusters']

    # Optionally narrow down to a single named cluster.
    if args.cluster:
        targets = [c for c in targets if c['name'] == args.cluster]
        if not targets:
            print(f"Error: Cluster '{args.cluster}' not found in config")
            sys.exit(1)

    print(f"Targeting {len(targets)} cluster(s)")
    print("=" * 50)

    if args.deploy:
        print("Deploying to clusters...")
        parallel_execute(targets, deploy_cluster)
    elif args.start:
        print("Starting workers...")
        parallel_execute(targets, start_workers, config)
    elif args.stop:
        print("Stopping workers...")
        parallel_execute(targets, stop_workers)
    elif args.status:
        print("Checking cluster status...")
        parallel_execute(targets, get_status)
    elif args.logs:
        # Locate the named cluster among the targets.
        match = None
        for c in targets:
            if c['name'] == args.logs:
                match = c
                break
        if match is None:
            print(f"Error: Cluster '{args.logs}' not found")
            sys.exit(1)

        name, success, stdout, stderr = get_logs(match, lines=100)
        print(f"=== Logs from {name} ===")
        print(stdout if success else stderr)

    print("\n" + "=" * 50)
    print("Done")


# Standard script entry guard: allows importing this module without running it.
if __name__ == "__main__":
    main()
