"""Training loop for LeWM TTS VQ model."""

import os
import json
import time
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path

from model_vq import build_model_vq
from dataset import build_dataloader


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    config = {
        "d_model": args.d_model,
        "nhead": args.nhead,
        "n_mels": 100,
        "text_vocab_size": 256,
        "text_encoder_layers": args.text_layers,
        "audio_encoder_layers": args.audio_layers,
        "predictor_layers": args.predictor_layers,
        "n_codes": args.n_codes,
        "dropout": args.dropout,
        "pred_weight": args.pred_weight,
        "recon_weight": args.recon_weight,
        "spectral_weight": args.spectral_weight,
        "commitment_weight": args.commitment_weight,
        "label_smoothing": args.label_smoothing,
        "downsample_factor": args.downsample_factor,
    }
    model, config = build_model_vq(config)
    model = model.to(device)

    manifest_path = os.path.join(args.data_dir, "manifest.json")
    loader, dataset = build_dataloader(
        manifest_path, batch_size=args.batch_size, num_workers=args.num_workers,
    )
    print(f"Training: {len(dataset)} samples, {len(loader)} batches/epoch")

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, betas=(0.9, 0.98), weight_decay=0.01,
    )

    total_steps = args.epochs * len(loader)
    warmup_steps = min(4000, total_steps // 10)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + torch.cos(torch.tensor(progress * 3.14159)).item())

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    ckpt_dir = Path(args.output_dir) / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(Path(args.output_dir) / "logs")

    with open(Path(args.output_dir) / "config.json", "w") as f:
        json.dump(config, f, indent=2)

    start_epoch = 0
    global_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt["global_step"]
        print(f"Resumed from epoch {start_epoch}, step {global_step}")

    model.train()
    best_loss = float("inf")

    for epoch in range(start_epoch, args.epochs):
        ep_loss = ep_tok = ep_commit = ep_recon = ep_spec = ep_acc = 0.0
        t0 = time.time()

        for batch_idx, batch in enumerate(loader):
            mel = batch["mel"].to(device)
            text_tokens = batch["text_tokens"].to(device)
            mel_mask = batch["mel_mask"].to(device)
            text_mask = batch["text_mask"].to(device)

            losses = model(mel, text_tokens, mel_mask, text_mask)

            total_loss = losses["total_loss"]

            optimizer.zero_grad()
            total_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            scheduler.step()

            ep_loss += total_loss.item()
            ep_tok += losses["token_loss"].item()
            ep_commit += losses["commit_loss"].item()
            ep_recon += losses["recon_loss"].item()
            ep_spec += losses["spectral_loss"].item()
            ep_acc += losses["token_accuracy"].item()
            global_step += 1

            if global_step % args.log_every == 0:
                lr = optimizer.param_groups[0]["lr"]
                for k, v in losses.items():
                    writer.add_scalar(f"train/{k}", v.item(), global_step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)
                writer.add_scalar("train/lr", lr, global_step)

            if global_step % args.print_every == 0:
                print(
                    f"  [{epoch+1}/{args.epochs}] step {global_step} | "
                    f"loss={total_loss.item():.4f} tok={losses['token_loss'].item():.4f} "
                    f"commit={losses['commit_loss'].item():.4f} "
                    f"recon={losses['recon_loss'].item():.4f} "
                    f"spec={losses['spectral_loss'].item():.4f} "
                    f"acc={losses['token_accuracy'].item():.3f} | "
                    f"grad={grad_norm.item():.2f} | {time.time()-t0:.1f}s"
                )

        n = len(loader)
        epoch_time = time.time() - t0
        print(
            f"Epoch {epoch+1}/{args.epochs} | "
            f"loss={ep_loss/n:.4f} tok={ep_tok/n:.4f} commit={ep_commit/n:.4f} "
            f"recon={ep_recon/n:.4f} spec={ep_spec/n:.4f} "
            f"acc={ep_acc/n:.3f} | {epoch_time:.1f}s"
        )

        if (epoch + 1) % args.save_every == 0 or (epoch + 1) == args.epochs:
            torch.save({
                "epoch": epoch, "global_step": global_step,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "config": config, "loss": ep_loss / n,
            }, ckpt_dir / f"epoch_{epoch+1:04d}.pt")

        if ep_loss / n < best_loss:
            best_loss = ep_loss / n
            torch.save({
                "epoch": epoch, "global_step": global_step,
                "model": model.state_dict(), "config": config, "loss": best_loss,
            }, ckpt_dir / "best.pt")
            print(f"  New best (loss={best_loss:.4f})")

    writer.close()
    print(f"\nDone. Best loss: {best_loss:.4f}")


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--data_dir", default="/home/ubuntu/lewm-tts/processed_data_100mel")
    p.add_argument("--output_dir", default="/home/ubuntu/lewm-tts/output_vq")
    p.add_argument("--d_model", type=int, default=256)
    p.add_argument("--nhead", type=int, default=4)
    p.add_argument("--text_layers", type=int, default=4)
    p.add_argument("--audio_layers", type=int, default=4)
    p.add_argument("--predictor_layers", type=int, default=6)
    p.add_argument("--n_codes", type=int, default=1024)
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--downsample_factor", type=int, default=4)
    p.add_argument("--pred_weight", type=float, default=1.0)
    p.add_argument("--recon_weight", type=float, default=1.0)
    p.add_argument("--spectral_weight", type=float, default=0.5)
    p.add_argument("--commitment_weight", type=float, default=0.25)
    p.add_argument("--label_smoothing", type=float, default=0.1)
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--log_every", type=int, default=10)
    p.add_argument("--print_every", type=int, default=100)
    p.add_argument("--save_every", type=int, default=25)
    p.add_argument("--resume", default=None)
    train(p.parse_args())
