"""Training for LeWM TTS v6 — EnCodec token prediction."""

import argparse
import json
import math
import os
import time
from pathlib import Path

import torch
from torch.utils.tensorboard import SummaryWriter

from model_v6 import build_model_v6
from dataset_v6 import build_dataloader


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

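    # Model hyperparameters for build_model_v6; n_codes matches the 1024-entry
    # EnCodec codebook, and text_vocab_size of 256 fits byte-level text input.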
    config = {
        "d_model": args.d_model, "nhead": args.nhead, "n_codes": 1024,
        "text_vocab_size": 256, "text_encoder_layers": args.text_layers,
        "predictor_layers": args.predictor_layers,
        "dropout": args.dropout, "label_smoothing": args.label_smoothing,
    }
    model, config = build_model_v6(config)
    model = model.to(device)

    manifest = os.path.join(args.data_dir, "manifest_tokens.json")
    loader, dataset = build_dataloader(manifest, batch_size=args.batch_size, num_workers=4)
    print(f"Training: {len(dataset)} samples, {len(loader)} batches/epoch")

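    # AdamW with the (0.9, 0.98) betas commonly used for Transformer training;
    # warmup lasts at most 4000 steps, capped at 10% of the total step budget.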
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), weight_decay=0.01)
    total_steps = args.epochs * len(loader)
    warmup = min(4000, total_steps // 10)

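    # LambdaLR multiplier: linear warmup to the base LR, then cosine decay to zero.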
    def lr_fn(step):
        # Use (step + 1) so the very first optimizer step is not taken at lr=0.
        if step < warmup:
            return (step + 1) / max(1, warmup)
        p = (step - warmup) / max(1, total_steps - warmup)
        return 0.5 * (1.0 + math.cos(math.pi * p))

    sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fn)

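    # Output layout: checkpoints/, TensorBoard logs/, and the resolved model config.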
    ckpt_dir = Path(args.output_dir) / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(Path(args.output_dir) / "logs")
    with open(Path(args.output_dir) / "config.json", "w") as f:
        json.dump(config, f, indent=2)

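    # Optionally resume from a full checkpoint, restoring model, optimizer,
    # scheduler, and step counters.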
    start_epoch = 0
    global_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        sched.load_state_dict(ckpt["scheduler"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt["global_step"]
        print(f"Resumed from epoch {start_epoch}")

    model.train()
    best_loss = float("inf")

    for epoch in range(start_epoch, args.epochs):
        ep_loss = ep_acc = 0.0
        t0 = time.time()

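        # Move the EnCodec target tokens, text tokens, and their masks to the device.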
        for batch in loader:
            tokens = batch["tokens"].to(device)
            text = batch["text_tokens"].to(device)
            tok_mask = batch["token_mask"].to(device)
            text_mask = batch["text_mask"].to(device)

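            # Forward pass returns a dict of scalar losses/metrics; total_loss drives backprop.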
            losses = model(tokens, text, tok_mask, text_mask)
            loss = losses["total_loss"]

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            sched.step()

            ep_loss += loss.item()
            ep_acc += losses["token_accuracy"].item()
            global_step += 1

            if global_step % args.log_every == 0:
                for k, v in losses.items():
                    writer.add_scalar(f"train/{k}", v.item(), global_step)

            if global_step % args.print_every == 0:
                print(f"  [{epoch+1}/{args.epochs}] step {global_step} | "
                      f"loss={loss.item():.4f} acc={losses['token_accuracy'].item():.3f} | "
                      f"{time.time()-t0:.1f}s")

        n = len(loader)
        avg_loss = ep_loss / n
        avg_acc = ep_acc / n
        print(f"Epoch {epoch+1}/{args.epochs} | loss={avg_loss:.4f} acc={avg_acc:.3f} | {time.time()-t0:.1f}s")

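        # Periodic full checkpoint (model + optimizer + scheduler), resumable via --resume.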
        if (epoch + 1) % args.save_every == 0 or (epoch + 1) == args.epochs:
            torch.save({
                "epoch": epoch, "global_step": global_step,
                "model": model.state_dict(), "optimizer": optimizer.state_dict(),
                "scheduler": sched.state_dict(), "config": config, "loss": avg_loss,
            }, ckpt_dir / f"epoch_{epoch+1:04d}.pt")

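        # Track the best epoch-average loss; best.pt stores only weights and config.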
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                "epoch": epoch, "global_step": global_step,
                "model": model.state_dict(), "config": config, "loss": best_loss,
            }, ckpt_dir / "best.pt")
            print(f"  New best (loss={best_loss:.4f}, acc={avg_acc:.3f})")

    writer.close()
    print(f"\nDone. Best loss: {best_loss:.4f}")


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--data_dir", default="/home/ubuntu/lewm-tts/processed_data_codec")
    p.add_argument("--output_dir", default="/home/ubuntu/lewm-tts/output_v6")
    p.add_argument("--d_model", type=int, default=256)
    p.add_argument("--nhead", type=int, default=4)
    p.add_argument("--text_layers", type=int, default=4)
    p.add_argument("--predictor_layers", type=int, default=8)
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--label_smoothing", type=float, default=0.1)
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--log_every", type=int, default=10)
    p.add_argument("--print_every", type=int, default=100)
    p.add_argument("--save_every", type=int, default=25)
    p.add_argument("--resume", default=None)
    train(p.parse_args())
