"""
Training loop for LeWM TTS.
"""

import argparse
import json
import math
import os
import time
from pathlib import Path

import torch
from torch.utils.tensorboard import SummaryWriter

from dataset import build_dataloader
from model import build_model


def train(args):
    """Run the full training loop for LeWM TTS.

    Builds the model and dataloader from ``args``, trains for
    ``args.epochs`` epochs with linear-warmup + cosine-decay LR,
    anneals the KL weight / input noise / scheduled-sampling rate over
    training, and writes TensorBoard logs plus periodic and best
    checkpoints under ``args.output_dir``.

    Args:
        args: argparse.Namespace carrying the hyperparameters defined
            in this file's ``__main__`` section. ``args.resume`` may
            point at a checkpoint produced by this function.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # ─── Build model ───
    config = {
        "d_model": args.d_model,
        "nhead": args.nhead,
        "n_mels": 100,
        "text_vocab_size": 256,
        "text_encoder_layers": args.text_layers,
        "audio_encoder_layers": args.audio_layers,
        "predictor_layers": args.predictor_layers,
        "dropout": args.dropout,
        "kl_weight": args.kl_weight,
        "recon_weight": args.recon_weight,
        "spectral_weight": args.spectral_weight,
        "downsample_factor": args.downsample_factor,
        "pred_weight": args.pred_weight,
        "ema_decay": args.ema_decay,
        "free_bits": args.free_bits,
        "input_noise": 0.0,  # will be annealed during training
        "cosine_weight": args.cosine_weight,
        "n_speakers": args.n_speakers,
        "speaker_weight": args.speaker_weight,
    }
    model, config = build_model(config)
    model = model.to(device)

    # ─── Build dataloader ───
    manifest_path = os.path.join(args.data_dir, "manifest.json")
    loader, dataset = build_dataloader(
        manifest_path, batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
    print(f"Training: {len(dataset)} samples, {len(loader)} batches/epoch")

    # ─── Optimizer & Scheduler ───
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr,
        betas=(0.9, 0.98), weight_decay=0.01,
    )

    # Warmup + cosine decay. max(1, ...) guards the empty-loader edge
    # case so the annealing denominators below never divide by zero.
    total_steps = max(1, args.epochs * len(loader))
    warmup_steps = min(4000, total_steps // 10)

    def lr_lambda(step):
        # Linear warmup to 1.0, then cosine decay to 0. math.cos on a
        # plain float replaces the original torch-tensor round-trip,
        # and math.pi is exact where the hard-coded 3.14159 was not.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # ─── Logging ───
    ckpt_dir = Path(args.output_dir) / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(Path(args.output_dir) / "logs")

    # Save config
    with open(Path(args.output_dir) / "config.json", "w") as f:
        json.dump(config, f, indent=2)

    # ─── Resume from checkpoint ───
    start_epoch = 0
    global_step = 0
    best_loss = float("inf")
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt["global_step"]
        # Restore the running best so a resumed run never overwrites
        # best.pt with a worse model. Older checkpoints lack this key,
        # so fall back to inf (the previous behavior).
        best_loss = ckpt.get("best_loss", float("inf"))
        print(f"Resumed from epoch {start_epoch}, step {global_step}")

    # ─── Training loop ───
    model.train()

    for epoch in range(start_epoch, args.epochs):
        # Per-epoch running sums for the console summary.
        epoch_loss = 0.0
        epoch_pred_loss = 0.0
        epoch_kl_loss = 0.0
        epoch_recon_loss = 0.0
        epoch_spectral_loss = 0.0
        epoch_speaker_loss = 0.0
        t0 = time.time()

        for batch_idx, batch in enumerate(loader):
            mel = batch["mel"].to(device)
            text_tokens = batch["text_tokens"].to(device)
            mel_mask = batch["mel_mask"].to(device)
            text_mask = batch["text_mask"].to(device)
            speaker_id = batch.get("speaker_id")
            if speaker_id is not None:
                speaker_id = speaker_id.to(device)

            # KL annealing: ramp up over first 20% of training
            kl_anneal = min(1.0, global_step / (total_steps * 0.2))
            model.kl_weight = args.kl_weight * kl_anneal

            # Input noise annealing: ramp up over first 30% of training
            # Bridges train/inference gap by simulating imperfect AR inputs
            noise_progress = min(1.0, global_step / (total_steps * 0.3))
            model.input_noise = args.input_noise_max * noise_progress

            # Scheduled sampling annealing: ramp up over 20%-80% of training
            # Starts after model has learned basic prediction, then gradually
            # trains it to handle its own imperfect outputs
            ss_progress = max(0.0, min(1.0, (global_step - total_steps * 0.2) / (total_steps * 0.6)))
            model.scheduled_sampling_rate = args.ss_max * ss_progress

            # Forward — model returns a dict of loss components.
            losses = model(mel, text_tokens, mel_mask, text_mask, speaker_id=speaker_id)

            total_loss = losses["total_loss"]
            pred_loss = losses["prediction_loss"]
            mse_loss = losses["mse_loss"]
            cosine_loss = losses["cosine_loss"]
            kl_loss = losses["kl_loss"]
            recon_loss = losses["recon_loss"]
            spectral_loss = losses["spectral_loss"]
            speaker_loss = losses["speaker_loss"]

            # Backward. set_to_none=True frees grad tensors instead of
            # zeroing them in place (cheaper, same semantics here).
            optimizer.zero_grad(set_to_none=True)
            total_loss.backward()

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            optimizer.step()
            scheduler.step()

            # Update EMA target encoder
            model.update_ema()

            # Logging
            epoch_loss += total_loss.item()
            epoch_pred_loss += pred_loss.item()
            epoch_kl_loss += kl_loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_spectral_loss += spectral_loss.item()
            epoch_speaker_loss += speaker_loss.item()
            global_step += 1

            if global_step % args.log_every == 0:
                lr = optimizer.param_groups[0]["lr"]
                writer.add_scalar("train/total_loss", total_loss.item(), global_step)
                writer.add_scalar("train/prediction_loss", pred_loss.item(), global_step)
                writer.add_scalar("train/mse_loss", mse_loss.item(), global_step)
                writer.add_scalar("train/cosine_loss", cosine_loss.item(), global_step)
                writer.add_scalar("train/kl_loss", kl_loss.item(), global_step)
                writer.add_scalar("train/recon_loss", recon_loss.item(), global_step)
                writer.add_scalar("train/spectral_loss", spectral_loss.item(), global_step)
                writer.add_scalar("train/speaker_loss", speaker_loss.item(), global_step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)
                writer.add_scalar("train/lr", lr, global_step)

            if global_step % args.print_every == 0:
                elapsed = time.time() - t0
                print(
                    f"  [{epoch+1}/{args.epochs}] step {global_step} | "
                    f"loss={total_loss.item():.4f} pred={pred_loss.item():.4f} "
                    f"kl={kl_loss.item():.4f} recon={recon_loss.item():.4f} "
                    f"spec={spectral_loss.item():.4f} spk={speaker_loss.item():.4f} | "
                    f"grad={grad_norm.item():.2f} | {elapsed:.1f}s"
                )

        # Epoch summary — max(1, ...) avoids ZeroDivisionError if the
        # dataloader yielded no batches.
        n_batches = max(1, len(loader))
        avg_loss = epoch_loss / n_batches
        avg_pred = epoch_pred_loss / n_batches
        avg_kl = epoch_kl_loss / n_batches
        avg_recon = epoch_recon_loss / n_batches
        avg_spectral = epoch_spectral_loss / n_batches
        avg_speaker = epoch_speaker_loss / n_batches
        epoch_time = time.time() - t0

        print(
            f"Epoch {epoch+1}/{args.epochs} | "
            f"avg_loss={avg_loss:.4f} pred={avg_pred:.4f} kl={avg_kl:.4f} "
            f"recon={avg_recon:.4f} spec={avg_spectral:.4f} spk={avg_speaker:.4f} | {epoch_time:.1f}s"
        )

        # Update the running best BEFORE saving so both checkpoint
        # kinds persist the correct best_loss for resume.
        is_best = avg_loss < best_loss
        if is_best:
            best_loss = avg_loss

        # Save checkpoint (periodic + final epoch)
        if (epoch + 1) % args.save_every == 0 or (epoch + 1) == args.epochs:
            ckpt_path = ckpt_dir / f"epoch_{epoch+1:04d}.pt"
            torch.save({
                "epoch": epoch,
                "global_step": global_step,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "config": config,
                "loss": avg_loss,
                "best_loss": best_loss,  # needed for correct resume
            }, ckpt_path)
            print(f"  Saved checkpoint: {ckpt_path}")

        # Save best (no optimizer/scheduler state: inference-only)
        if is_best:
            best_path = ckpt_dir / "best.pt"
            torch.save({
                "epoch": epoch,
                "global_step": global_step,
                "model": model.state_dict(),
                "config": config,
                "loss": avg_loss,
                "best_loss": best_loss,
            }, best_path)
            print(f"  New best model (loss={avg_loss:.4f})")

    writer.close()
    print(f"\nTraining complete. Best loss: {best_loss:.4f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Data
    parser.add_argument("--data_dir", default="/home/ubuntu/lewm-tts/processed_data")
    parser.add_argument("--output_dir", default="/home/ubuntu/lewm-tts/output")
    # Model
    parser.add_argument("--d_model", type=int, default=256)
    parser.add_argument("--nhead", type=int, default=4)
    parser.add_argument("--text_layers", type=int, default=4)
    parser.add_argument("--audio_layers", type=int, default=4)
    parser.add_argument("--predictor_layers", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--downsample_factor", type=int, default=4)
    parser.add_argument("--kl_weight", type=float, default=0.05)
    parser.add_argument("--recon_weight", type=float, default=1.0)
    parser.add_argument("--spectral_weight", type=float, default=0.5)
    parser.add_argument("--pred_weight", type=float, default=10.0)
    parser.add_argument("--ema_decay", type=float, default=0.998)
    parser.add_argument("--free_bits", type=float, default=2.0)
    parser.add_argument("--input_noise_max", type=float, default=0.2)
    parser.add_argument("--cosine_weight", type=float, default=1.0)
    parser.add_argument("--ss_max", type=float, default=0.5, help="Max scheduled sampling rate")
    parser.add_argument("--n_speakers", type=int, default=1)
    parser.add_argument("--speaker_weight", type=float, default=1.0)
    # Training
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--num_workers", type=int, default=4)
    # Logging
    parser.add_argument("--log_every", type=int, default=10)
    parser.add_argument("--print_every", type=int, default=50)
    parser.add_argument("--save_every", type=int, default=10)
    parser.add_argument("--resume", default=None)
    args = parser.parse_args()
    train(args)
