"""Training for LeWM TTS v7 — AR + NAR on 8-level EnCodec tokens."""
import argparse
import json
import math
import os
import time
from pathlib import Path

import torch
from torch.utils.tensorboard import SummaryWriter

from dataset_v7 import build_dataloader
from model_v7 import build_model_v7


def train(args):
    """Train the v7 AR+NAR TTS model on 8-level EnCodec token sequences.

    Builds the model and dataloader from ``args``, runs the training loop with
    a linear-warmup + cosine-decay LR schedule, logs scalars to TensorBoard,
    and writes periodic checkpoints plus a running ``best.pt`` (lowest average
    epoch loss) under ``args.output_dir``.

    Args:
        args: parsed CLI namespace (see the ``__main__`` argparse block);
            notable fields are ``data_dir``, ``output_dir``, model sizes,
            ``epochs``, ``batch_size``, ``lr``, and optional ``resume``
            (path to a checkpoint produced by this function).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = {
        "d_model": args.d_model, "nhead": args.nhead, "n_codes": 1024, "n_rvq": 8,
        "text_vocab_size": 256, "text_encoder_layers": args.text_layers,
        "ar_layers": args.ar_layers, "nar_layers": args.nar_layers,
        "dropout": args.dropout, "label_smoothing": args.label_smoothing,
        "n_speakers": args.n_speakers,
    }
    # build_model_v7 may normalize/augment the config, so keep its returned copy.
    model, config = build_model_v7(config)
    model = model.to(device)

    loader, dataset = build_dataloader(
        str(Path(args.data_dir) / "manifest.json"),
        batch_size=args.batch_size, num_workers=4)
    print(f"Training: {len(dataset)} samples, {len(loader)} batches/epoch")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), weight_decay=0.01)
    total_steps = args.epochs * len(loader)
    warmup = min(4000, total_steps // 10)

    def lr_fn(s):
        # Linear warmup, then cosine decay to 0 over the remaining steps.
        # math.cos avoids allocating a torch tensor on every scheduler step
        # (the original used torch.cos(torch.tensor(...)).item() with 3.14159).
        if s < warmup:
            return s / max(1, warmup)
        progress = (s - warmup) / max(1, total_steps - warmup)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
    ckpt_dir = Path(args.output_dir) / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(Path(args.output_dir) / "logs")
    with open(Path(args.output_dir) / "config.json", "w") as f:
        json.dump(config, f, indent=2)

    start_epoch, global_step, best_loss = 0, 0, float("inf")
    if args.resume:
        # weights_only=False: checkpoint contains optimizer/scheduler state and
        # a plain config dict; only load checkpoints produced by this script.
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        opt.load_state_dict(ckpt["optimizer"])
        sched.load_state_dict(ckpt["scheduler"])
        start_epoch, global_step = ckpt["epoch"] + 1, ckpt["global_step"]
        # Restore the running best so resuming cannot overwrite best.pt with a
        # worse model. Older checkpoints lack "best_loss"; fall back to inf.
        best_loss = ckpt.get("best_loss", float("inf"))

    model.train()
    for epoch in range(start_epoch, args.epochs):
        # Per-epoch accumulators for averaged console reporting.
        ep = {k: 0.0 for k in ["loss", "ar", "nar", "ar_acc", "nar_acc"]}
        t0 = time.time()
        for batch in loader:
            tok = batch["tokens"].to(device)
            text = batch["text_tokens"].to(device)
            tok_mask = batch["token_mask"].to(device)
            text_mask = batch["text_mask"].to(device)
            spk = batch.get("speaker_id")
            if spk is not None:
                spk = spk.to(device)

            # Model returns a dict of scalar loss/accuracy tensors.
            losses = model(tok, text, tok_mask, text_mask, speaker_id=spk)
            # set_to_none=True is the cheaper, recommended way to clear grads.
            opt.zero_grad(set_to_none=True)
            losses["total_loss"].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()
            sched.step()

            ep["loss"] += losses["total_loss"].item()
            ep["ar"] += losses["ar_loss"].item()
            ep["nar"] += losses["nar_loss"].item()
            ep["ar_acc"] += losses["ar_acc"].item()
            ep["nar_acc"] += losses["nar_acc"].item()
            global_step += 1

            if global_step % args.log_every == 0:
                for k, v in losses.items():
                    writer.add_scalar(f"train/{k}", v.item(), global_step)

            if global_step % args.print_every == 0:
                print(f"  [{epoch+1}/{args.epochs}] step {global_step} | "
                      f"ar={losses['ar_loss'].item():.3f} nar={losses['nar_loss'].item():.3f} "
                      f"ar_acc={losses['ar_acc'].item():.3f} nar_acc={losses['nar_acc'].item():.3f} | "
                      f"{time.time()-t0:.1f}s")

        n = len(loader)
        avg_loss = ep["loss"] / n
        print(f"Epoch {epoch+1}/{args.epochs} | loss={avg_loss:.4f} "
              f"ar={ep['ar']/n:.3f} nar={ep['nar']/n:.3f} "
              f"ar_acc={ep['ar_acc']/n:.3f} nar_acc={ep['nar_acc']/n:.3f} | "
              f"{time.time()-t0:.1f}s")

        if (epoch + 1) % args.save_every == 0 or (epoch + 1) == args.epochs:
            # Full resumable checkpoint: model + optimizer + scheduler state,
            # plus "best_loss" so a later --resume keeps the best-model logic.
            torch.save({"epoch": epoch, "global_step": global_step,
                         "model": model.state_dict(), "optimizer": opt.state_dict(),
                         "scheduler": sched.state_dict(), "config": config,
                         "loss": avg_loss, "best_loss": min(best_loss, avg_loss)},
                        ckpt_dir / f"epoch_{epoch+1:04d}.pt")

        if avg_loss < best_loss:
            best_loss = avg_loss
            # best.pt is inference-only (no optimizer/scheduler state).
            torch.save({"epoch": epoch, "global_step": global_step,
                         "model": model.state_dict(), "config": config,
                         "loss": best_loss, "best_loss": best_loss},
                        ckpt_dir / "best.pt")
            print(f"  New best (loss={best_loss:.4f})")

    writer.close()
    print(f"\nDone. Best: {best_loss:.4f}")


if __name__ == "__main__":
    # CLI entry point: defaults target the v7 data/output layout.
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--data_dir", default="/home/ubuntu/lewm-tts/processed_data_v7")
    add("--output_dir", default="/home/ubuntu/lewm-tts/output_v7")
    # Model architecture.
    add("--d_model", type=int, default=256)
    add("--nhead", type=int, default=4)
    add("--text_layers", type=int, default=4)
    add("--ar_layers", type=int, default=8)
    add("--nar_layers", type=int, default=6)
    add("--dropout", type=float, default=0.1)
    add("--label_smoothing", type=float, default=0.1)
    add("--n_speakers", type=int, default=1)
    # Optimization.
    add("--epochs", type=int, default=200)
    add("--batch_size", type=int, default=16)
    add("--lr", type=float, default=3e-4)
    add("--grad_clip", type=float, default=1.0)
    # Logging / checkpointing cadence.
    add("--log_every", type=int, default=10)
    add("--print_every", type=int, default=100)
    add("--save_every", type=int, default=25)
    add("--resume", default=None)
    train(parser.parse_args())
