"""
Train predictor + text encoder + duration predictor with frozen projections.
Loads proj_in/proj_out from a v5 checkpoint (already good at roundtrip).
Text is expanded to frame level via duration predictor — can't be ignored.
"""

import argparse
import json
import math
import os
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from model_v5 import (
    LeWMTTSv5, build_model, JEPAPredictor, TextEncoder,
    DurationPredictor, compute_uniform_durations, length_regulate,
)
from dataset_v5 import build_dataloader


def _load_compatible(model, pretrained_state):
    # Partial checkpoint load: copy tensors whose name AND shape both match;
    # anything else (missing keys, resized layers) is skipped rather than raising.
    model_state = model.state_dict()
    loaded_keys, skipped_keys = [], []
    for k, v in pretrained_state.items():
        if k in model_state and model_state[k].shape == v.shape:
            model_state[k] = v
            loaded_keys.append(k)
        else:
            skipped_keys.append(k)
    model.load_state_dict(model_state)
    return loaded_keys, skipped_keys


def _reinit_heads(model, config, device):
    # Rebuild predictor, text encoder, duration predictor and start embedding
    # from scratch, discarding any matching weights the checkpoint provided.
    d_model = config.get("d_model", 256)
    nhead = config.get("nhead", 4)
    dropout = config.get("dropout", 0.1)
    model.predictor = JEPAPredictor(
        d_model=d_model, nhead=nhead,
        num_layers=config.get("predictor_layers", 6), dropout=dropout,
    ).to(device)
    model.text_encoder = TextEncoder(
        vocab_size=config.get("text_vocab_size", 256),
        d_model=d_model, nhead=nhead,
        num_layers=config.get("text_encoder_layers", 4),
        dropout=dropout,
    ).to(device)
    model.duration_predictor = DurationPredictor(
        d_model=d_model, kernel_size=3, num_layers=2, dropout=dropout,
    ).to(device)
    model.start_emb = torch.nn.Parameter(torch.randn(1, 1, d_model, device=device) * 0.02)


def _build_frame_mask(pred_mask, mask_ratio, device):
    # Build a [B, T, 1] 0/1 mask of contiguous audio-frame blocks to zero out,
    # covering ~mask_ratio of each sample's valid frames. Position 0 (the start
    # embedding) is never masked; sequences shorter than 10 frames are skipped.
    B, T = pred_mask.shape
    frame_mask = torch.zeros(B, T, 1, device=device)
    valid_lens = (~pred_mask).sum(dim=1)  # [B] valid (non-padding) frame counts
    for b in range(B):
        vlen = valid_lens[b].item()
        if vlen < 10:
            continue
        target_masked = int(vlen * mask_ratio)
        masked_so_far = 0
        while masked_so_far < target_masked:
            # Random block length 15-50 frames (~0.2-0.7s); never overshoot the target.
            block_len = min(torch.randint(15, 51, (1,)).item(),
                            target_masked - masked_so_far)
            # Random start position (skip pos 0 = start_emb).
            start_pos = torch.randint(1, max(2, vlen - block_len), (1,)).item()
            frame_mask[b, start_pos:start_pos + block_len] = 1.0
            # NOTE(review): blocks may overlap, so actual coverage can fall
            # short of target_masked — this matches the original behavior.
            masked_so_far += block_len
    return frame_mask


def train(args):
    """Train predictor + text encoder + duration predictor with frozen projections.

    Loads proj_in/proj_out (and ema_proj_in) weights from ``args.pretrained``
    and freezes them; only the autoregressive predictor, text encoder, duration
    predictor and start embedding are optimized. Prediction targets are
    produced by the frozen proj_in, so the latent space stays fixed throughout.

    Args:
        args: parsed CLI namespace (see the ``__main__`` argparse setup).

    Side effects: writes TensorBoard logs, ``config.json`` and checkpoints
    under ``args.output_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load pretrained checkpoint for projections.
    # weights_only=False because the checkpoint also stores a plain config dict.
    ckpt = torch.load(args.pretrained, map_location=device, weights_only=False)
    config = ckpt["config"]
    config["pred_weight"] = args.pred_weight
    config["dur_weight"] = args.dur_weight
    config["input_noise"] = 0.0  # noise is applied manually below on a step schedule

    model = LeWMTTSv5(config).to(device)

    # Load what we can from the pretrained weights (projections, etc.).
    loaded_keys, skipped_keys = _load_compatible(model, ckpt["model"])
    print(f"Loaded {len(loaded_keys)} params from {args.pretrained} (skipped {len(skipped_keys)})")
    if skipped_keys:
        print(f"  Skipped: {skipped_keys[:10]}{'...' if len(skipped_keys) > 10 else ''}")

    # Freeze proj_in/proj_out (and the EMA copy) — they're already good.
    for module in (model.proj_in, model.proj_out, model.ema_proj_in):
        for p in module.parameters():
            p.requires_grad = False

    # Optionally reinitialize the trainable heads from scratch.
    if args.reinit:
        print("Reinitializing predictor + text encoder + duration predictor from scratch")
        _reinit_heads(model, config, device)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {trainable/1e6:.2f}M (predictor + text encoder + dur predictor + start_emb)")

    manifest_path = os.path.join(args.data_dir, "manifest.json")
    loader, dataset = build_dataloader(
        manifest_path, batch_size=args.batch_size,
        num_workers=args.num_workers, max_codec_frames=args.max_codec_frames,
    )
    print(f"Training: {len(dataset)} samples, {len(loader)} batches/epoch")

    # Only optimize trainable params.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr, betas=(0.9, 0.98), weight_decay=0.01,
    )

    total_steps = args.epochs * len(loader)
    warmup_steps = min(2000, total_steps // 10)

    def lr_lambda(step):
        # Linear warmup followed by cosine decay to zero.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        # math.cos with math.pi: no per-step tensor allocation, exact pi
        # (was torch.cos(torch.tensor(progress * 3.14159))).
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    ckpt_dir = Path(args.output_dir) / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(Path(args.output_dir) / "logs")

    with open(Path(args.output_dir) / "config.json", "w") as f:
        json.dump(config, f, indent=2)

    model.train()
    # Keep the frozen projections in eval mode so their behavior stays fixed.
    model.proj_in.eval()
    model.proj_out.eval()
    best_loss = float("inf")
    global_step = 0

    for epoch in range(args.epochs):
        epoch_pred = 0.0  # running sum of per-batch prediction loss
        epoch_dur = 0.0   # running sum of per-batch duration loss
        t0 = time.time()

        for batch in loader:
            codec_emb = batch["codec_emb"].to(device)    # [B, 128, T] per transpose below
            text_tokens = batch["text_tokens"].to(device)
            codec_mask = batch["codec_mask"].to(device)  # bool, True = padding
            text_mask = batch["text_mask"].to(device)    # bool, True = padding

            B = codec_emb.shape[0]
            T = codec_emb.shape[2]
            codec_seq = codec_emb.transpose(1, 2)  # [B, T, 128]

            # Text encoding (trainable)
            text_emb = model.text_encoder(text_tokens, text_mask)

            # Duration prediction (log-domain)
            log_dur_pred = model.duration_predictor(text_emb, text_mask)

            # Ground-truth durations: uniform split of audio frames over text tokens.
            text_lengths = (~text_mask).sum(dim=1)
            audio_lengths = (~codec_mask).sum(dim=1)
            gt_durations = compute_uniform_durations(text_lengths, audio_lengths)
            log_dur_gt = torch.log(gt_durations.float().clamp(min=1))

            # Duration loss: masked MSE in log-duration space.
            valid_text = (~text_mask).float()
            dur_loss = F.mse_loss(
                log_dur_pred * valid_text, log_dur_gt * valid_text, reduction="sum"
            ) / (valid_text.sum() + 1e-8)

            # Expand text to frame level using GT durations, then trim/pad to T.
            text_emb_expanded = length_regulate(text_emb, gt_durations, text_mask)
            if text_emb_expanded.shape[1] > T:
                text_emb_expanded = text_emb_expanded[:, :T]
            elif text_emb_expanded.shape[1] < T:
                # dtype must match text_emb or the cat fails under mixed precision.
                pad = torch.zeros(B, T - text_emb_expanded.shape[1], text_emb.shape[-1],
                                  device=device, dtype=text_emb.dtype)
                text_emb_expanded = torch.cat([text_emb_expanded, pad], dim=1)

            # Project with FROZEN proj_in — targets are stable.
            with torch.no_grad():
                z = model.proj_in(codec_seq)  # [B, T, d_model]
                target_emb = z.clone()

            # Teacher forcing: prepend start embedding, shift right by one frame.
            start = model.start_emb.expand(B, -1, -1)
            input_emb = torch.cat([start, z[:, :-1]], dim=1)

            # Padding mask for the shifted input (start position is always valid).
            start_mask = torch.zeros(B, 1, dtype=torch.bool, device=device)
            pred_mask = torch.cat([start_mask, codec_mask[:, :-1]], dim=1)

            # Input noise, ramped linearly over the first 30% of training.
            noise_progress = min(1.0, global_step / (total_steps * 0.3))
            noise_level = args.input_noise_max * noise_progress
            if noise_level > 0:
                noise = torch.zeros_like(input_emb)
                noise[:, 1:] = torch.randn(B, T - 1, input_emb.shape[-1],
                                           device=device) * noise_level
                input_emb = input_emb + noise

            # Frame masking — zero out audio frames to force text reliance.
            if args.frame_mask_ratio > 0:
                frame_mask = _build_frame_mask(pred_mask, args.frame_mask_ratio, device)
                input_emb = input_emb * (1.0 - frame_mask)

            # Predict next-frame latents conditioned on frame-level text.
            predicted = model.predictor(input_emb, text_emb_expanded, pred_mask)

            # Prediction loss (masked MSE vs frozen proj_in targets).
            valid = (~codec_mask).unsqueeze(-1)
            pred_loss = F.mse_loss(
                predicted * valid, target_emb * valid, reduction="sum"
            ) / (valid.sum() * predicted.shape[-1] + 1e-8)

            total_loss = args.pred_weight * pred_loss + args.dur_weight * dur_loss

            optimizer.zero_grad()
            total_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                [p for p in model.parameters() if p.requires_grad], args.grad_clip
            )
            optimizer.step()
            scheduler.step()

            epoch_pred += pred_loss.item()
            epoch_dur += dur_loss.item()
            global_step += 1

            if global_step % args.log_every == 0:
                writer.add_scalar("train/pred_loss", pred_loss.item(), global_step)
                writer.add_scalar("train/dur_loss", dur_loss.item(), global_step)
                writer.add_scalar("train/total_loss", total_loss.item(), global_step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)
                writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_step)

            if global_step % args.print_every == 0:
                elapsed = time.time() - t0
                print(
                    f"  [{epoch+1}/{args.epochs}] step {global_step} | "
                    f"pred={pred_loss.item():.4f} dur={dur_loss.item():.4f} | "
                    f"grad={grad_norm.item():.2f} | {elapsed:.1f}s"
                )

        avg_loss = epoch_pred / len(loader)
        avg_dur = epoch_dur / len(loader)
        epoch_time = time.time() - t0
        print(f"Epoch {epoch+1}/{args.epochs} | pred={avg_loss:.4f} dur={avg_dur:.4f} | {epoch_time:.1f}s")

        if (epoch + 1) % args.save_every == 0 or (epoch + 1) == args.epochs:
            ckpt_path = ckpt_dir / f"epoch_{epoch+1:04d}.pt"
            torch.save({
                "epoch": epoch,
                "global_step": global_step,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "config": config,
                "loss": avg_loss,
            }, ckpt_path)
            print(f"  Saved: {ckpt_path}")

        # Best checkpoint is ranked by average prediction loss only
        # (duration loss is deliberately excluded).
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                "epoch": epoch,
                "global_step": global_step,
                "model": model.state_dict(),
                "config": config,
                "loss": avg_loss,
            }, ckpt_dir / "best.pt")
            print(f"  New best (pred={avg_loss:.4f})")

    writer.close()
    print(f"\nDone. Best pred loss: {best_loss:.4f}")


if __name__ == "__main__":
    # CLI entry point: parse hyper-parameters and launch training.
    cli = argparse.ArgumentParser()
    # Checkpoint / data locations.
    cli.add_argument("--pretrained", default="output_v5/checkpoints/epoch_0120.pt")
    cli.add_argument("--data_dir", default="processed_data_codec")
    cli.add_argument("--output_dir", default="output_v5_pred")
    # Loss weighting and input corruption.
    cli.add_argument("--pred_weight", type=float, default=1.0)
    cli.add_argument("--dur_weight", type=float, default=1.0)
    cli.add_argument("--input_noise_max", type=float, default=0.05)
    cli.add_argument(
        "--frame_mask_ratio", type=float, default=0.0,
        help="Fraction of input audio frames to zero out (forces text reliance)",
    )
    # Optimization hyper-parameters.
    cli.add_argument("--max_codec_frames", type=int, default=900)
    cli.add_argument("--epochs", type=int, default=200)
    cli.add_argument("--batch_size", type=int, default=256)
    cli.add_argument("--lr", type=float, default=5e-5)
    cli.add_argument("--grad_clip", type=float, default=1.0)
    cli.add_argument("--num_workers", type=int, default=4)
    # Logging / checkpoint cadence.
    cli.add_argument("--log_every", type=int, default=10)
    cli.add_argument("--print_every", type=int, default=20)
    cli.add_argument("--save_every", type=int, default=10)
    cli.add_argument("--reinit", action="store_true", help="Reinitialize predictor from scratch")
    train(cli.parse_args())
