"""
Training loop for LeWM TTS v5 (codec-based JEPA).
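
Example invocation (script filename assumed; flags shown are the argparse defaults below):

    python train_v5.py --data_dir processed_data_codec --output_dir output_v5 \
        --epochs 200 --batch_size 64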
"""

import os
import json
import time
import argparse
import math
import torch
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path

from model_v5 import build_model
from dataset_v5 import build_dataloader


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    config = {
        "d_model": args.d_model,
        "nhead": args.nhead,
        "codec_dim": 128,
        "text_vocab_size": 256,
        "text_encoder_layers": args.text_layers,
        "predictor_layers": args.predictor_layers,
        "dropout": args.dropout,
        "pred_weight": args.pred_weight,
        "roundtrip_weight": args.roundtrip_weight,
        "ema_decay": args.ema_decay,
        "input_noise": 0.0,
    }
    model, config = build_model(config)
    model = model.to(device)

    manifest_path = os.path.join(args.data_dir, "manifest.json")
    loader, dataset = build_dataloader(
        manifest_path, batch_size=args.batch_size,
        num_workers=args.num_workers, max_codec_frames=args.max_codec_frames,
    )
    print(f"Training: {len(dataset)} samples, {len(loader)} batches/epoch")

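    # AdamW with the (0.9, 0.98) betas commonly used for transformer training.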
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr,
        betas=(0.9, 0.98), weight_decay=0.01,
    )

    total_steps = args.epochs * len(loader)
    warmup_steps = min(4000, total_steps // 10)

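    # LR schedule: linear warmup for the first warmup_steps, then cosine decay to zero.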
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

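    # Output layout: checkpoints/ and TensorBoard logs/ under args.output_dir.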
    ckpt_dir = Path(args.output_dir) / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(Path(args.output_dir) / "logs")

    with open(Path(args.output_dir) / "config.json", "w") as f:
        json.dump(config, f, indent=2)

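    # Optionally resume full training state (model/optimizer/scheduler) from a checkpoint.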
    start_epoch = 0
    global_step = 0
    best_loss = float("inf")
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt["global_step"]
        # Restore the recorded loss so best.pt is not overwritten by a
        # worse epoch immediately after resuming.
        best_loss = ckpt.get("loss", float("inf"))
        print(f"Resumed from epoch {start_epoch}, step {global_step}")

    model.train()

    for epoch in range(start_epoch, args.epochs):
        epoch_loss = 0.0
        epoch_pred = 0.0
        epoch_rt = 0.0
        t0 = time.time()

        for batch_idx, batch in enumerate(loader):
            codec_emb = batch["codec_emb"].to(device)
            text_tokens = batch["text_tokens"].to(device)
            codec_mask = batch["codec_mask"].to(device)
            text_mask = batch["text_mask"].to(device)

            # Ramp input noise up from 0 to input_noise_max over the first
            # 30% of total training steps, then hold it constant.
            noise_progress = min(1.0, global_step / (total_steps * 0.3))
            model.input_noise = args.input_noise_max * noise_progress

            losses = model(codec_emb, text_tokens, codec_mask, text_mask)

            total_loss = losses["total_loss"]
            pred_loss = losses["prediction_loss"]
            rt_loss = losses["roundtrip_loss"]

            optimizer.zero_grad()
            total_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            scheduler.step()
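            # Momentum update of the EMA target encoder (rate set by --ema_decay).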
            model.update_ema()

            epoch_loss += total_loss.item()
            epoch_pred += pred_loss.item()
            epoch_rt += rt_loss.item()
            global_step += 1

            if global_step % args.log_every == 0:
                lr = optimizer.param_groups[0]["lr"]
                writer.add_scalar("train/total_loss", total_loss.item(), global_step)
                writer.add_scalar("train/prediction_loss", pred_loss.item(), global_step)
                writer.add_scalar("train/roundtrip_loss", rt_loss.item(), global_step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)
                writer.add_scalar("train/lr", lr, global_step)

            if global_step % args.print_every == 0:
                elapsed = time.time() - t0
                print(
                    f"  [{epoch+1}/{args.epochs}] step {global_step} | "
                    f"loss={total_loss.item():.4f} pred={pred_loss.item():.4f} "
                    f"rt={rt_loss.item():.4f} | "
                    f"grad={grad_norm.item():.2f} | {elapsed:.1f}s"
                )

        n_batches = len(loader)
        avg_loss = epoch_loss / n_batches
        avg_pred = epoch_pred / n_batches
        avg_rt = epoch_rt / n_batches
        epoch_time = time.time() - t0

        print(
            f"Epoch {epoch+1}/{args.epochs} | "
            f"avg_loss={avg_loss:.4f} pred={avg_pred:.4f} rt={avg_rt:.4f} | {epoch_time:.1f}s"
        )

        if (epoch + 1) % args.save_every == 0 or (epoch + 1) == args.epochs:
            ckpt_path = ckpt_dir / f"epoch_{epoch+1:04d}.pt"
            torch.save({
                "epoch": epoch,
                "global_step": global_step,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "config": config,
                "loss": avg_loss,
            }, ckpt_path)
            print(f"  Saved checkpoint: {ckpt_path}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                "epoch": epoch,
                "global_step": global_step,
                "model": model.state_dict(),
                "config": config,
                "loss": avg_loss,
            }, ckpt_dir / "best.pt")
            print(f"  New best (loss={avg_loss:.4f})")

    writer.close()
    print(f"\nTraining complete. Best loss: {best_loss:.4f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="processed_data_codec")
    parser.add_argument("--output_dir", default="output_v5")
    parser.add_argument("--d_model", type=int, default=256)
    parser.add_argument("--nhead", type=int, default=4)
    parser.add_argument("--text_layers", type=int, default=4)
    parser.add_argument("--predictor_layers", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--ds_stride", type=int, default=3)
    parser.add_argument("--pred_weight", type=float, default=10.0)
    parser.add_argument("--roundtrip_weight", type=float, default=1.0)
    parser.add_argument("--ema_decay", type=float, default=0.998)
    parser.add_argument("--input_noise_max", type=float, default=0.1)
    parser.add_argument("--max_codec_frames", type=int, default=900)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--log_every", type=int, default=10)
    parser.add_argument("--print_every", type=int, default=50)
    parser.add_argument("--save_every", type=int, default=10)
    parser.add_argument("--resume", default=None)
    args = parser.parse_args()
    train(args)
