#!/usr/bin/env python3
"""
Finetuning script for Cohere Transcribe (cohere-transcribe-03-2026).

2B Fast-Conformer encoder + Transformer decoder.
Reads pre-computed mel spectrogram shards (output of data preprocessing pipeline).
Supports DDP, FSDP, gradient checkpointing, EMA, dynamic batching.

Usage:
    # Single GPU (debugging)
    python train.py --config config.yaml

    # Multi-GPU single node
    torchrun --nproc_per_node=8 train.py --config config.yaml

    # Multi-node
    torchrun --nnodes=4 --nproc_per_node=8 \
        --rdzv_backend=c10d --rdzv_endpoint=master:29500 \
        train.py --config config.yaml
"""

import os
import sys
import math
import time
import json
import copy
import signal
import argparse
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Dict, List

import torch
import torch.nn as nn
import torch.distributed as dist
from safetensors.torch import load_file as load_safetensors
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
)
import numpy as np

from dataset import DataConfig, DynamicBatchSampler, create_dataloader
from dataset_fast import FastDataConfig, create_fast_dataloader
from tokenizer_utils import load_extended_tokenizer

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

@dataclass
class TrainConfig:
    """All knobs for a finetuning run of Cohere Transcribe.

    Grouped by concern: model source, data, optimization schedule,
    regularization, EMA, precision, distributed strategy, checkpoint/log
    cadence, and resume. Consumed throughout the Trainer class.
    """

    # Model
    model_name: str = "CohereLabs/cohere-transcribe-03-2026"  # HF hub id or local path
    trust_remote_code: bool = True  # model ships custom architecture code from the hub

    # Data
    data: DataConfig = field(default_factory=DataConfig)
    num_workers: int = 8  # dataloader worker processes

    # Training
    epochs: int = 2
    learning_rate: float = 5e-5  # peak LR (decoder / deepest encoder layer)
    min_lr: float = 1e-6  # floor of the cosine decay
    warmup_steps: int = 0              # if 0, computed from warmup_ratio
    warmup_ratio: float = 0.02          # 2% of total steps
    weight_decay: float = 0.01  # skipped for bias / norm params (see LLRD grouping)
    gradient_clip: float = 1.0  # max global gradient norm
    gradient_accumulation_steps: int = 4
    label_smoothing: float = 0.1  # applied manually on top of NLL in train_step

    # Layer-wise LR decay (LLRD)
    llrd_factor: float = 0.9  # each lower layer gets lr * factor^depth

    # Regularization
    dropout_rate: float = 0.05  # override model's dropout during finetuning
    use_spec_augment: bool = True  # SpecAugment on mel inputs during training

    # EMA
    use_ema: bool = True
    ema_decay: float = 0.999

    # Precision
    bf16: bool = True  # bfloat16 weights + autocast in the forward pass

    # Distributed
    strategy: str = "ddp"       # "ddp" or "fsdp"
    gradient_checkpointing: bool = True

    # Checkpointing
    output_dir: str = "./checkpoints"
    save_every_steps: int = 5000
    eval_every_steps: int = 10000
    log_every_steps: int = 50
    max_checkpoints: int = 5  # retention limit — presumably enforced by the checkpoint saver (not shown here)

    # Wandb
    wandb_project: str = "cohere-transcribe-finetune"
    wandb_run_name: Optional[str] = None  # None → wandb picks a name

    # Max steps (0 = train for full epochs)
    max_steps: int = 0

    # Resume
    resume_from: Optional[str] = None  # checkpoint path to resume from


# ---------------------------------------------------------------------------
# EMA
# ---------------------------------------------------------------------------

class EMAModel:
    """Tracks an exponential moving average of trainable parameters.

    After each optimizer step, call ``update``. Around evaluation, call
    ``apply_shadow`` to swap the EMA weights into the model and ``restore``
    to put the live weights back.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.backup = {}
        # Shadow copies of every trainable parameter at construction time.
        self.shadow = {
            n: p.data.clone()
            for n, p in model.named_parameters()
            if p.requires_grad
        }

    @torch.no_grad()
    def update(self, model: nn.Module):
        """Blend current weights into the shadow: s ← d·s + (1−d)·p."""
        d = self.decay
        for n, p in model.named_parameters():
            tracked = self.shadow.get(n)
            if p.requires_grad and tracked is not None:
                tracked.mul_(d).add_(p.data, alpha=1.0 - d)

    def apply_shadow(self, model: nn.Module):
        """Swap EMA weights into the model, backing up the live weights."""
        for n, p in model.named_parameters():
            if n in self.shadow:
                self.backup[n] = p.data.clone()
                p.data.copy_(self.shadow[n])

    def restore(self, model: nn.Module):
        """Put the backed-up live weights back and drop the backup."""
        for n, p in model.named_parameters():
            saved = self.backup.get(n)
            if saved is not None:
                p.data.copy_(saved)
        self.backup = {}

    def state_dict(self):
        return {'shadow': self.shadow, 'decay': self.decay}

    def load_state_dict(self, state_dict):
        self.shadow = state_dict['shadow']
        self.decay = state_dict['decay']


# ---------------------------------------------------------------------------
# SpecAugment (on pre-computed mels at training time)
# ---------------------------------------------------------------------------

class SpecAugment(nn.Module):
    """SpecAugment applied to mel spectrograms during training."""

    def __init__(
        self,
        freq_mask_param: int = 27,
        time_mask_param: int = 40,
        num_freq_masks: int = 2,
        num_time_masks: int = 2,
    ):
        super().__init__()
        self.freq_mask_param = freq_mask_param
        self.time_mask_param = time_mask_param
        self.num_freq_masks = num_freq_masks
        self.num_time_masks = num_time_masks

    def forward(self, mel: torch.Tensor, mel_lengths: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel: [B, n_mels, T] float tensor
            mel_lengths: [B] actual lengths (for adaptive time masking)
        Returns:
            Augmented mel tensor (in-place modification)
        """
        if not self.training:
            return mel

        B, n_mels, T = mel.shape

        for b in range(B):
            length = mel_lengths[b].item()

            # Frequency masks
            for _ in range(self.num_freq_masks):
                f = random.randint(0, self.freq_mask_param)
                f0 = random.randint(0, max(0, n_mels - f))
                mel[b, f0:f0 + f, :length] = 0.0

            # Time masks (adaptive to utterance length)
            max_t = min(self.time_mask_param, int(length * 0.05))
            for _ in range(self.num_time_masks):
                t = random.randint(0, max(0, max_t))
                t0 = random.randint(0, max(0, length - t))
                mel[b, :, t0:t0 + t] = 0.0

        return mel


import random


# ---------------------------------------------------------------------------
# Layer-wise LR Decay
# ---------------------------------------------------------------------------

def get_parameter_groups_with_llrd(model, base_lr, llrd_factor, weight_decay):
    """
    Build per-parameter optimizer groups with layer-wise LR decay (LLRD).

    Encoder layer ``k`` receives ``base_lr * llrd_factor ** (max_depth - k)``,
    so lower (earlier) encoder layers train more slowly and preserve
    pretrained features. Non-layer encoder params, the decoder, and all other
    params get the full ``base_lr``. Parameters whose (lowercased) name
    contains 'bias', 'layer_norm', or 'layernorm' get zero weight decay.

    Each group carries a 'name' key so the scheduler (``update_lr``) can
    re-derive the per-group LLRD multiplier every step.

    Args:
        model: module whose trainable parameters are grouped.
        base_lr: full learning rate (decoder / deepest encoder layer).
        llrd_factor: per-layer decay factor.
        weight_decay: decay for non-excluded parameters.

    Returns:
        List of param-group dicts for ``torch.optim`` (one per parameter).
    """
    import re

    layer_re = re.compile(r'layers?[._](\d+)')  # compiled once, reused below
    no_decay = ('bias', 'layer_norm', 'layernorm')

    # Deepest numbered encoder layer anchors the LLRD schedule.
    max_encoder_depth = 0
    for name, _ in model.named_parameters():
        if 'encoder' in name.lower():
            matches = layer_re.findall(name)
            if matches:
                max_encoder_depth = max(max_encoder_depth, max(int(m) for m in matches))

    if max_encoder_depth == 0:
        max_encoder_depth = 1  # fallback when no numbered encoder layers exist

    param_groups = []
    # named_parameters() already deduplicates shared parameters, so no
    # separate 'seen' bookkeeping is needed.
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        name_lower = name.lower()
        if 'encoder' in name_lower:
            matches = layer_re.findall(name)
            # Non-layer encoder params get the full encoder LR.
            depth = max(int(m) for m in matches) if matches else max_encoder_depth
            # LLRD: deeper layers get higher LR.
            lr = base_lr * llrd_factor ** (max_encoder_depth - depth)
        else:
            # Decoder and other params get full LR.
            lr = base_lr

        # No weight decay on bias / normalization parameters.
        wd = 0.0 if any(nd in name_lower for nd in no_decay) else weight_decay

        param_groups.append({
            'params': [param],
            'lr': lr,
            'weight_decay': wd,
            'name': name,
        })

    return param_groups


# ---------------------------------------------------------------------------
# LR Schedule: Linear warmup + Cosine decay
# ---------------------------------------------------------------------------

def get_lr(step, warmup_steps, total_steps, base_lr, min_lr):
    """Scheduled LR at `step`: linear warmup to base_lr, then cosine to min_lr."""
    if step < warmup_steps:
        # Linear ramp 0 → base_lr over the warmup window.
        return base_lr * step / max(1, warmup_steps)
    # Cosine anneal from base_lr down to min_lr over the remaining steps.
    decay_span = max(1, total_steps - warmup_steps)
    progress = (step - warmup_steps) / decay_span
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine


def update_lr(optimizer, step, warmup_steps, total_steps, base_lr, min_lr, llrd_factor, max_encoder_depth):
    """Set every param group's LR for `step`, preserving LLRD multipliers.

    Recomputes the scheduled base LR via ``get_lr`` and rescales each encoder
    group by ``llrd_factor ** (max_encoder_depth - depth)``, mirroring the
    multipliers assigned in ``get_parameter_groups_with_llrd`` (each group
    carries its parameter name under the 'name' key for this purpose).

    Returns:
        The unscaled scheduled LR (decoder / deepest encoder layer LR).
    """
    import re  # hoisted: was previously imported inside the loop, per group

    layer_re = re.compile(r'layers?[._](\d+)')
    current_base_lr = get_lr(step, warmup_steps, total_steps, base_lr, min_lr)
    for group in optimizer.param_groups:
        name = group.get('name', '')
        if 'encoder' in name.lower():
            matches = layer_re.findall(name)
            # Non-layer encoder params behave like the deepest layer (full LR).
            depth = max(int(m) for m in matches) if matches else max_encoder_depth
            group['lr'] = current_base_lr * llrd_factor ** (max_encoder_depth - depth)
        else:
            group['lr'] = current_base_lr
    return current_base_lr


# ---------------------------------------------------------------------------
# Training Loop
# ---------------------------------------------------------------------------

class Trainer:
    def __init__(self, config: TrainConfig):
        """Build the full training context from a resolved config.

        Setup order matters: distributed init must precede model placement,
        the model/processor must exist before data setup (which builds
        tokenizer prompts from the model) and before the optimizer, and
        checkpoint resume happens last.
        """
        self.config = config
        self.setup_distributed()   # sets rank / world_size / device / is_main
        self.setup_model()         # loads + wraps model; sets raw_model
        self.setup_data()          # dataloader + per-language decoder prompts
        self.setup_optimizer()     # AdamW over LLRD param groups
        self.setup_ema()
        self.setup_logging()

        # Progress counters (expected to be overwritten by load_checkpoint on resume).
        self.global_step = 0
        self.epoch = 0
        self.best_wer = float('inf')

        # SpecAugment is only active while the module is in training mode.
        if config.use_spec_augment:
            self.spec_augment = SpecAugment()
        else:
            self.spec_augment = None

        # Graceful shutdown handler: SIGTERM sets a flag that the training
        # loop checks each batch so we checkpoint before exiting.
        self._shutdown = False
        signal.signal(signal.SIGTERM, self._handle_sigterm)

        if config.resume_from:
            self.load_checkpoint(config.resume_from)

    def _handle_sigterm(self, signum, frame):
        """SIGTERM handler (e.g. cluster preemption).

        Only sets a flag; the training loop checks ``self._shutdown`` at the
        top of each batch and performs the actual checkpoint + exit there,
        rather than doing heavy work inside the signal handler.
        """
        logger.info("SIGTERM received, saving checkpoint and shutting down...")
        self._shutdown = True

    def setup_distributed(self):
        """Initialize distributed training."""
        if 'RANK' in os.environ:
            dist.init_process_group(backend='nccl')
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
            self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device(f'cuda:{self.local_rank}')
            self.is_main = self.rank == 0
        else:
            self.rank = 0
            self.world_size = 1
            self.local_rank = 0
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
            self.is_main = True

        if self.is_main:
            logger.info(f"World size: {self.world_size}, Device: {self.device}")

    def setup_model(self):
        """Load the Cohere Transcribe model and wrap it for training.

        Order: processor + extended tokenizer → model weights (bf16 optional)
        → gradient checkpointing (native, or manual fallback) → freeze the
        conv/feature frontend → move to device → DDP/FSDP wrap.
        Sets self.processor, self.model (possibly wrapped), self.raw_model.
        """
        config = self.config

        if self.is_main:
            logger.info(f"Loading model: {config.model_name}")

        self.processor = AutoProcessor.from_pretrained(
            config.model_name,
            trust_remote_code=config.trust_remote_code,
        )
        # Swap in the extended tokenizer from tokenizer_utils (the sanity
        # check in setup_data relies on its control tokens staying atomic).
        self.processor.tokenizer = load_extended_tokenizer(config.model_name)

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            config.model_name,
            trust_remote_code=config.trust_remote_code,
            dtype=torch.bfloat16 if config.bf16 else torch.float32,
        )

        # Enable gradient checkpointing if the model supports it
        if config.gradient_checkpointing:
            if getattr(model, "supports_gradient_checkpointing", False):
                model.gradient_checkpointing_enable()
                if self.is_main:
                    logger.info("Gradient checkpointing enabled")
            else:
                # Model doesn't support HF gradient_checkpointing_enable(), but we can
                # manually wrap encoder layers with torch.utils.checkpoint
                self._enable_manual_grad_checkpointing(model)
                if self.is_main:
                    logger.info("Manual gradient checkpointing enabled for encoder layers")

        # Freeze feature extractor / conv frontend only
        # The rest (all encoder + decoder layers) are unfrozen for language extension
        frozen_count = 0
        for name, param in model.named_parameters():
            name_lower = name.lower()
            # Freeze only the conv subsampling / feature extractor frontend
            # NOTE(review): substring match — assumes these names identify the
            # frontend in this checkpoint; also frozen_count counts individual
            # parameters even though the log line calls them "groups".
            if any(k in name_lower for k in ['conv_subsample', 'feature_extractor', 'embed_positions']):
                param.requires_grad = False
                frozen_count += 1

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        if self.is_main:
            logger.info(f"Parameters: {total/1e9:.2f}B total, {trainable/1e9:.2f}B trainable, {frozen_count} frozen groups")

        # ── Encoder SDPA: deferred (relative position bias incompatible with naive SDPA) ──
        self._patch_encoder_attention(model)

        # ── torch.compile: DISABLED ──
        # The model's _setup_compile() compiles each ConformerLayer with dynamic=True,
        # but torch._inductor has a codegen bug (undefined 's5' variable) with the
        # variable-length RelPositionMultiHeadAttention in this conformer architecture.
        # Disabled until upstream torch fixes this. Training still works at full speed
        # without compile — the bottleneck is the encoder attention, not small ops.
        # To re-enable when fixed: model._setup_compile(processor=self.processor)

        # Wrap with DDP or FSDP
        model = model.to(self.device)

        if self.world_size > 1:
            if config.strategy == "fsdp":
                # FSDP with SHARD_GRAD_OP (like ZeRO-2)
                bf16_policy = MixedPrecision(
                    param_dtype=torch.bfloat16,
                    reduce_dtype=torch.bfloat16,
                    buffer_dtype=torch.bfloat16,
                ) if config.bf16 else None

                self.model = FSDP(
                    model,
                    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
                    mixed_precision=bf16_policy,
                    device_id=self.local_rank,
                    use_orig_params=True,  # needed for torch.compile
                )
            else:
                # Standard DDP
                self.model = DDP(
                    model,
                    device_ids=[self.local_rank],
                    find_unused_parameters=False,
                )
        else:
            self.model = model

        # Get the raw model for parameter access
        self.raw_model = self.model.module if hasattr(self.model, 'module') else self.model

    def _enable_manual_grad_checkpointing(self, model):
        """Manually wrap encoder ConformerLayer calls with torch.utils.checkpoint.

        The Cohere ASR model sets supports_gradient_checkpointing=False,
        but we can still checkpoint the encoder conformer layers manually.
        This saves ~40% activation memory at the cost of ~30% extra compute.

        We wrap each layer.forward so the encoder's own forward() doesn't need changes.
        """
        from torch.utils.checkpoint import checkpoint

        encoder = None
        for name, module in model.named_modules():
            if hasattr(module, 'layers') and hasattr(module, 'pre_encode'):
                encoder = module
                break

        if encoder is None:
            if self.is_main:
                logger.warning("Could not find encoder for manual grad checkpointing")
            return

        # Wrap each ConformerLayer's forward with checkpoint
        for i, layer in enumerate(encoder.layers):
            original_fwd = layer.forward

            def make_ckpt_fwd(orig):
                def ckpt_fwd(*args, **kwargs):
                    # checkpoint doesn't support kwargs, so we need to handle this
                    # ConformerLayer.forward(x, pos_emb, mask=None, pad_mask=None)
                    return checkpoint(orig, *args, use_reentrant=False, **kwargs)
                return ckpt_fwd

            layer.forward = make_ckpt_fwd(original_fwd)

    def _patch_encoder_attention(self, model):
        """No-op: encoder uses RelPositionMultiHeadAttention with relative position
        bias that combines content and position scores before softmax. Standard SDPA
        cannot handle this without double-counting the content term. The encoder
        attention is left as-is (manual matmul+softmax).

        torch.compile on each ConformerLayer already provides fusion benefits for
        the encoder. A correct SDPA rewrite would need to pass only the position
        bias (matrix_bd) as attn_mask and use unbiased q for the SDPA content term,
        but that changes the attention semantics (pos_bias_u vs pos_bias_v).
        Deferred until we can validate numerical equivalence.
        """
        pass

    def setup_data(self):
        """Build the training dataloader and per-language decoder prompts.

        Also fail-fast sanity-checks that the extended tokenizer keeps the
        decoder control tokens atomic (one id per control token).
        """
        config = self.config

        # Pass tokenizer to dataset for on-the-fly tokenization
        tokenizer = self.processor.tokenizer
        self.tokenizer = tokenizer

        # Build per-language prompt token IDs for decoder conditioning
        # Prompt format: <|startofcontext|><|startoftranscript|><|emo:undefined|>
        #                <|lang|><|lang|><|pnc|><|noitn|><|notimestamp|><|nodiarize|>
        self.lang_prompt_ids = {}
        for lang in ["en", "hi", "te", "ta", "ml", "bn", "gu", "kn", "pa", "mr", "or", "as"]:
            # build_prompt comes from the model's custom code — presumably it
            # emits the control-token string above; verify against the model repo.
            prompt_str = self.raw_model.build_prompt(language=lang, punctuation=True)
            prompt_ids = tokenizer.encode(prompt_str, add_special_tokens=False)
            self.lang_prompt_ids[lang] = prompt_ids
            if self.is_main and lang == "hi":
                logger.info(f"Decoder prompt ({lang}): {prompt_str} → {len(prompt_ids)} tokens")

        # Fail fast if the tokenizer is splitting decoder control tokens.
        # Every prompt id must map back to a registered special token, and a
        # split prompt would inflate the id count past the 16-token bound.
        sanity_ids = self.lang_prompt_ids["hi"]
        sanity_tokens = tokenizer.convert_ids_to_tokens(sanity_ids)
        if len(sanity_ids) > 16 or not all(tok in tokenizer.all_special_tokens for tok in sanity_tokens):
            raise RuntimeError(
                "Tokenizer sanity check failed: decoder prompt control tokens are being split. "
                "Expected the extended fast tokenizer from tokenizer.json."
            )

        # The IterableDataset automatically uses extracted .npy files if they exist,
        # falling back to tar streaming otherwise. No separate fast path needed.
        self.train_loader = create_dataloader(
            config.data, tokenizer=tokenizer, split="train",
            num_workers=config.num_workers,
            prebatch=True,
            max_batch_mel_frames=config.data.max_batch_mel_frames,
            max_batch_utterances=config.data.max_batch_utterances,
        )
        # NOTE(review): this sampler is constructed but never handed to the
        # dataloader above — it appears to exist only so train_epoch can call
        # set_epoch() on it; confirm against dataset.py.
        self.batch_sampler = DynamicBatchSampler(
            max_batch_mel_frames=config.data.max_batch_mel_frames,
            max_batch_utterances=config.data.max_batch_utterances,
        )

        if self.is_main:
            logger.info(
                f"DataLoader: {config.num_workers} workers, "
                f"max_batch_mel_frames={config.data.max_batch_mel_frames}, "
                f"max_batch_utterances={config.data.max_batch_utterances}"
            )

    def setup_optimizer(self):
        """Setup AdamW with layer-wise LR decay."""
        config = self.config

        # Detect encoder depth
        import re
        self.max_encoder_depth = 0
        for name, _ in self.raw_model.named_parameters():
            if 'encoder' in name.lower():
                matches = re.findall(r'layers?[._](\d+)', name)
                if matches:
                    self.max_encoder_depth = max(self.max_encoder_depth, max(int(m) for m in matches))

        param_groups = get_parameter_groups_with_llrd(
            self.raw_model, config.learning_rate, config.llrd_factor, config.weight_decay
        )

        self.optimizer = torch.optim.AdamW(
            param_groups,
            lr=config.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-9,
        )

        if self.is_main:
            logger.info(f"Optimizer: AdamW, base_lr={config.learning_rate}, "
                        f"LLRD={config.llrd_factor}, max_encoder_depth={self.max_encoder_depth}")

    def setup_ema(self):
        """Setup EMA if enabled."""
        if self.config.use_ema:
            self.ema = EMAModel(self.raw_model, decay=self.config.ema_decay)
            if self.is_main:
                logger.info(f"EMA enabled with decay={self.config.ema_decay}")
        else:
            self.ema = None

    def setup_logging(self):
        """Setup wandb logging."""
        self.wandb = None
        if self.is_main:
            try:
                import wandb
                wandb.init(
                    project=self.config.wandb_project,
                    name=self.config.wandb_run_name,
                    config=vars(self.config),
                )
                self.wandb = wandb
            except Exception as e:
                logger.warning(f"wandb init failed ({e}), logging to stdout only")
                self.wandb = None

            Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)

    def _build_decoder_inputs(self, batch: dict):
        """Build decoder_input_ids and labels with language prompt prefix.

        For each sample: decoder_input_ids = [prompt..., transcript_tokens...]
                         labels            = [-100...,   transcript_tokens..., eos]

        The model does NOT shift labels internally — we must construct
        decoder_input_ids (teacher-forced input) and labels (targets) explicitly.
        """
        transcript_ids = batch['labels']  # [B, max_tokens], padded with -100
        token_lengths = batch['token_lengths']
        languages = batch['languages']
        B = transcript_ids.shape[0]

        # Determine max sequence length after adding prompt
        max_prompt_len = max(len(self.lang_prompt_ids.get(l, self.lang_prompt_ids["en"])) for l in languages)
        max_transcript_len = int(token_lengths.max().item())
        max_seq_len = max_prompt_len + max_transcript_len + 1  # +1 for EOS

        eos_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")

        decoder_input_ids = torch.full((B, max_seq_len), pad_id, dtype=torch.long)
        labels = torch.full((B, max_seq_len), -100, dtype=torch.long)  # ignore index
        decoder_attention_mask = torch.zeros((B, max_seq_len), dtype=torch.long)

        for i in range(B):
            lang = languages[i]
            prompt = self.lang_prompt_ids.get(lang, self.lang_prompt_ids["en"])
            prompt_len = len(prompt)
            n_tok = int(token_lengths[i].item())

            # Extract valid transcript tokens (skip -100 padding)
            transcript = transcript_ids[i, :n_tok].tolist()

            # The Cohere ASR model does NOT shift labels internally.
            # logits[i] = f(decoder_input_ids[0:i+1]), loss = CE(logits[i], labels[i])
            # So labels[i] must be the NEXT token after position i.
            #
            # decoder_input_ids: [prompt...,        t0,   t1, ..., t_{n-1}]
            # labels:            [-100*(p_len-1),   t0,   t1, ..., t_{n-1}, eos]
            #
            # At pos prompt_len-1: model sees full prompt → predict t0
            # At pos prompt_len:   model sees prompt+t0   → predict t1
            # At pos prompt_len+n-1: model sees prompt+transcript → predict eos

            seq = prompt + transcript
            total_len = len(seq)

            decoder_input_ids[i, :total_len] = torch.tensor(seq, dtype=torch.long)
            decoder_attention_mask[i, :total_len] = 1

            # Labels: shifted left by 1 relative to decoder_input_ids
            # First transcript token label at prompt_len - 1
            for j in range(n_tok):
                labels[i, prompt_len - 1 + j] = transcript[j]
            # EOS label after last transcript token
            labels[i, prompt_len - 1 + n_tok] = eos_id

        return decoder_input_ids, labels, decoder_attention_mask

    def train_step(self, batch: dict) -> dict:
        """Run forward + backward on one micro-batch; return logging metrics.

        Gradients ACCUMULATE across calls (no zero_grad here); the loss is
        pre-divided by gradient_accumulation_steps so the summed gradient
        matches one large batch, and the reported 'loss' is rescaled back
        to the per-batch value.
        """
        config = self.config

        # Move to device — non_blocking=True overlaps H2D copy with compute
        mel = batch['mel'].to(self.device, dtype=torch.bfloat16 if config.bf16 else torch.float32, non_blocking=True)
        mel_lengths = batch['mel_lengths'].to(self.device, non_blocking=True)
        token_lengths = batch['token_lengths'].to(self.device, non_blocking=True)

        # Build decoder inputs (CPU work — overlaps with H2D transfer above)
        decoder_input_ids, labels, decoder_attention_mask = self._build_decoder_inputs(batch)
        decoder_input_ids = decoder_input_ids.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        decoder_attention_mask = decoder_attention_mask.to(self.device, non_blocking=True)

        # Sync point — ensure all transfers complete before compute
        # NOTE(review): assumes a CUDA device; this line would raise on CPU-only runs.
        torch.cuda.current_stream().synchronize()

        # SpecAugment (on GPU, fast) — mutates mel in place, training mode only
        if self.spec_augment and self.model.training:
            mel = self.spec_augment(mel, mel_lengths)

        B, n_mels, T = mel.shape

        # Forward pass — Cohere ASR model requires:
        #   input_features: [B, n_mels, T] mel spectrograms
        #   length: [B] raw mel frame counts (for encoder length inference)
        #   decoder_input_ids (or input_ids): [B, S] decoder token inputs
        #   decoder_attention_mask: [B, S] decoder padding mask
        #   labels: [B, S] targets with -100 for ignored positions
        with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=config.bf16):
            # Pass labels=None so the model returns logits without computing its own loss
            # (model's internal CE loss doesn't support label smoothing)
            outputs = self.model(
                input_features=mel,
                length=mel_lengths,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                labels=None,
            )
            logits = outputs.logits  # [B, S, vocab_size]
            # The model head applies log_softmax (config.head.log_softmax=true),
            # so logits are already log-probabilities. Use NLLLoss, NOT CrossEntropyLoss
            # (which would apply log_softmax again, producing wrong gradients).
            loss_fct = nn.NLLLoss(
                ignore_index=-100,
            )
            # For label smoothing with NLLLoss, apply manually:
            # smooth_loss = -(1-ε)*NLL - ε*mean(log_probs)
            log_probs = logits.view(-1, logits.size(-1))
            target = labels.view(-1)
            nll = loss_fct(log_probs, target)
            if config.label_smoothing > 0:
                # Smooth component: uniform distribution over vocab
                smooth = -log_probs.mean(dim=-1)
                # Mask ignored positions (-100) out of the smoothing mean
                mask = (target != -100).float()
                smooth = (smooth * mask).sum() / mask.sum()
                loss = (1 - config.label_smoothing) * nll + config.label_smoothing * smooth
            else:
                loss = nll

        # Scale loss for gradient accumulation
        loss = loss / config.gradient_accumulation_steps
        loss.backward()

        # 'loss' is rescaled for logging; frame/token totals feed throughput stats.
        return {
            'loss': loss.item() * config.gradient_accumulation_steps,
            'batch_size': B,
            'total_frames': mel_lengths.sum().item(),
            'total_tokens': token_lengths.sum().item(),
        }

    def train_epoch(self, epoch: int, total_steps: int):
        """Train for one epoch."""
        config = self.config
        self.model.train()

        # Set epoch on sampler for proper shuffling (fast dataset)
        if hasattr(self.batch_sampler, 'set_epoch'):
            self.batch_sampler.set_epoch(epoch)

        step_metrics = []
        accum_step = 0
        epoch_start = time.time()

        # Iterate batches from DataLoader (pre-batched by workers or by BucketBatchSampler).
        for batch in self.train_loader:
            if self._shutdown:
                self.save_checkpoint("shutdown")
                return

            metrics = self.train_step(batch)
            step_metrics.append(metrics)
            accum_step += 1

            # Gradient accumulation boundary
            if accum_step % config.gradient_accumulation_steps == 0:
                # Gradient clipping
                if config.strategy == "fsdp":
                    grad_norm = self.model.clip_grad_norm_(config.gradient_clip)
                else:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.raw_model.parameters(), config.gradient_clip
                    )

                # Check for NaN
                if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                    logger.warning(f"Step {self.global_step}: NaN/Inf gradient norm, skipping")
                    self.optimizer.zero_grad(set_to_none=True)
                    step_metrics = []
                    continue

                # Optimizer step
                self.optimizer.step()
                self.optimizer.zero_grad(set_to_none=True)

                # Update LR
                current_lr = update_lr(
                    self.optimizer, self.global_step, config.warmup_steps,
                    total_steps, config.learning_rate, config.min_lr,
                    config.llrd_factor, self.max_encoder_depth
                )

                # EMA update
                if self.ema:
                    self.ema.update(self.raw_model)

                self.global_step += 1

                # Max steps check
                if config.max_steps > 0 and self.global_step >= config.max_steps:
                    if self.is_main:
                        logger.info(f"Reached max_steps={config.max_steps}, stopping training")
                    self.save_checkpoint(f"step-{self.global_step}")
                    return

                # Logging
                if self.is_main and self.global_step % config.log_every_steps == 0:
                    avg_loss = np.mean([m['loss'] for m in step_metrics])
                    total_frames = sum(m['total_frames'] for m in step_metrics)
                    elapsed = time.time() - epoch_start
                    audio_hours = total_frames / 100 / 3600  # 100fps mel
                    throughput = audio_hours / (elapsed / 3600) if elapsed > 0 else 0

                    log_dict = {
                        'train/loss': avg_loss,
                        'train/lr': current_lr,
                        'train/grad_norm': grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
                        'train/throughput_rtx': throughput,
                        'train/step': self.global_step,
                        'train/epoch': epoch,
                    }

                    logger.info(
                        f"Step {self.global_step} | loss={avg_loss:.4f} | "
                        f"lr={current_lr:.2e} | grad={grad_norm:.2f} | "
                        f"throughput={throughput:.1f}x RT"
                    )

                    if self.wandb:
                        self.wandb.log(log_dict, step=self.global_step)

                    step_metrics = []

                # Save checkpoint
                if self.global_step % config.save_every_steps == 0:
                    self.save_checkpoint(f"step-{self.global_step}")

                # Evaluation
                if self.global_step % config.eval_every_steps == 0:
                    self.evaluate()
                    self.model.train()

    def evaluate(self):
        """Placeholder dev-set evaluation; runs on the main process only.

        Swaps in EMA-averaged weights (when EMA is enabled) for the duration
        of the eval, then restores the live training weights. Replace the
        body with a real WER computation on the dev set.
        """
        if not self.is_main:
            return

        self.model.eval()

        # Evaluate the EMA-averaged model rather than the raw training weights.
        ema = self.ema
        if ema:
            ema.apply_shadow(self.raw_model)

        # TODO: Add actual WER evaluation on dev set
        logger.info(f"Step {self.global_step}: Evaluation placeholder — implement WER eval on dev set")

        # Put the live training weights back before training resumes.
        if ema:
            ema.restore(self.raw_model)
    def save_checkpoint(self, tag: str) -> None:
        """Atomically save model, processor, and trainer state under <output_dir>/<tag>.

        All ranks enter this method (the two barriers keep them in lockstep),
        but only the main process touches the filesystem. Writes go to a
        hidden ".<tag>.tmp" directory first and are renamed into place at the
        end, so an interrupted save never leaves a half-written checkpoint at
        the final path.
        """
        # Entry barrier: every rank finishes its current step before rank 0
        # snapshots the weights/optimizer.
        if dist.is_initialized():
            dist.barrier()

        if self.is_main:
            import shutil

            output_dir = Path(self.config.output_dir)
            ckpt_dir = output_dir / tag
            tmp_dir = output_dir / f".{tag}.tmp"

            # Discard any leftover temp dir from a previously interrupted save.
            if tmp_dir.exists():
                shutil.rmtree(tmp_dir)
            tmp_dir.mkdir(parents=True, exist_ok=True)

            # Save into a temporary directory first so interrupted saves do not
            # leave behind a half-written "latest" checkpoint.
            self.raw_model.save_pretrained(tmp_dir / "model")
            self.processor.save_pretrained(tmp_dir / "processor")

            # Trainer bookkeeping needed by load_checkpoint() to resume
            # exactly where training left off.
            state = {
                'global_step': self.global_step,
                'epoch': self.epoch,
                'best_wer': self.best_wer,
                'optimizer': self.optimizer.state_dict(),
            }
            if self.ema:
                state['ema'] = self.ema.state_dict()

            torch.save(state, tmp_dir / "training_state.pt")

            # Swap the finished temp dir into the final location.
            # NOTE(review): there is a brief window between rmtree and rename
            # where neither the old nor the new checkpoint exists at ckpt_dir.
            if ckpt_dir.exists():
                shutil.rmtree(ckpt_dir)
            tmp_dir.rename(ckpt_dir)
            logger.info(f"Checkpoint saved: {ckpt_dir}")

            # Push to R2 every 20K steps (async, non-blocking)
            if self.global_step > 0 and self.global_step % 20000 == 0:
                import subprocess
                r2_script = Path(__file__).parent / "push_to_r2.sh"
                if r2_script.exists():
                    # Fire-and-forget: output is discarded so a slow or failing
                    # upload never blocks or spams the training job.
                    subprocess.Popen(
                        ["bash", str(r2_script), str(ckpt_dir), str(self.global_step)],
                        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                    )
                    logger.info(f"R2 push started for step {self.global_step} (async)")

            # Cleanup old checkpoints
            self._cleanup_checkpoints()

        # Exit barrier: non-main ranks wait here until the save completes
        # before resuming training.
        if dist.is_initialized():
            dist.barrier()

    def load_checkpoint(self, path: str) -> None:
        """Resume training from a checkpoint directory written by save_checkpoint.

        Restores model weights, optimizer state, EMA shadow (if present), and
        the step/epoch counters.

        Args:
            path: Checkpoint directory (e.g. "<output_dir>/step-20000").

        Raises:
            FileNotFoundError: if no model weights exist under <path>/model.
        """
        ckpt_dir = Path(path)
        model_dir = ckpt_dir / "model"

        # Resume must restore the actual trained weights, not just optimizer state.
        # save_pretrained() shards large models into multiple safetensors files
        # plus an index, so accept either layout instead of failing on shards.
        single_file = model_dir / "model.safetensors"
        index_file = model_dir / "model.safetensors.index.json"
        if single_file.exists():
            model_state = load_safetensors(str(single_file), device="cpu")
        elif index_file.exists():
            with open(index_file) as f:
                weight_map = json.load(f)["weight_map"]
            model_state = {}
            for shard_name in sorted(set(weight_map.values())):
                model_state.update(
                    load_safetensors(str(model_dir / shard_name), device="cpu")
                )
        else:
            raise FileNotFoundError(f"Missing checkpoint model weights: {single_file}")
        self.raw_model.load_state_dict(model_state, strict=True)

        # weights_only=True: the state file holds only tensors and primitives,
        # so refuse arbitrary pickled objects (safer against tampered files).
        state = torch.load(
            ckpt_dir / "training_state.pt", map_location="cpu", weights_only=True
        )
        self.global_step = state['global_step']
        self.epoch = state.get('epoch', 0)
        self.best_wer = state.get('best_wer', float('inf'))
        self.optimizer.load_state_dict(state['optimizer'])
        if self.ema and 'ema' in state:
            self.ema.load_state_dict(state['ema'])
        if self.is_main:
            logger.info(f"Resumed from {path}, step={self.global_step}")
        if dist.is_initialized():
            dist.barrier()

    def _cleanup_checkpoints(self) -> None:
        """Delete the oldest "step-*" checkpoints, keeping the newest max_checkpoints.

        Only directories named "step-<N>" are considered; epoch-* / final
        checkpoints are never removed.
        """
        import shutil  # hoisted: was re-imported inside the loop on every deletion

        output_dir = Path(self.config.output_dir)
        # Sort numerically by step so e.g. "step-9000" orders before "step-20000".
        ckpts = sorted(
            (d for d in output_dir.iterdir() if d.is_dir() and d.name.startswith("step-")),
            key=lambda d: int(d.name.split("-")[1]),
        )
        while len(ckpts) > self.config.max_checkpoints:
            old = ckpts.pop(0)
            shutil.rmtree(old)
            logger.info(f"Removed old checkpoint: {old}")

    def train(self):
        """Main training loop: estimate the LR schedule length, run epochs, save finals."""
        config = self.config

        # Estimate total steps from manifest and batch config
        # avg_frames per utterance ~700 (7s * 100fps), frame budget = 1.2M
        # → ~1714 utts/batch, with grad_accum and 8 GPUs:
        # effective_batch = 1714 * grad_accum * world_size
        import pandas as pd
        try:
            n_utterances = len(pd.read_parquet(config.data.manifest_path, columns=["segment_id"]))
        except Exception:
            n_utterances = 74_500_000  # fallback when the manifest is unreadable
        avg_frames = 700  # ~7 seconds at 100fps, empirical average
        utts_per_batch_by_frames = config.data.max_batch_mel_frames // avg_frames
        # Account for the hard utterance cap if set
        if config.data.max_batch_utterances > 0:
            utts_per_batch = min(utts_per_batch_by_frames, config.data.max_batch_utterances)
        else:
            utts_per_batch = utts_per_batch_by_frames
        steps_per_epoch = n_utterances // (utts_per_batch * config.gradient_accumulation_steps * max(self.world_size, 1))
        total_steps = steps_per_epoch * config.epochs

        # Compute warmup_steps from ratio if not explicitly set
        if config.warmup_steps == 0 and config.warmup_ratio > 0:
            config.warmup_steps = int(total_steps * config.warmup_ratio)

        if self.is_main:
            logger.info(f"Starting training: {config.epochs} epochs, ~{total_steps} estimated steps")
            logger.info(f"Warmup: {config.warmup_steps} steps ({config.warmup_ratio*100:.0f}%)")
            logger.info(f"Strategy: {config.strategy}, BF16: {config.bf16}")
            logger.info(f"Gradient accumulation: {config.gradient_accumulation_steps}")

        # self.epoch is nonzero after load_checkpoint(), so resume skips
        # already-completed epochs.
        for epoch in range(self.epoch, config.epochs):
            self.epoch = epoch
            if self.is_main:
                logger.info(f"=== Epoch {epoch + 1}/{config.epochs} ===")

            self.train_epoch(epoch, total_steps)

            # End-of-epoch checkpoint
            self.save_checkpoint(f"epoch-{epoch}")
            self.evaluate()

        # Final save with EMA weights.
        # BUGFIX: this block must run on ALL ranks, not just main.
        # save_checkpoint() calls dist.barrier() on entry and exit, so gating
        # the "final-ema" save on is_main gave the main rank two extra barrier
        # calls that other ranks never reached — mismatched barrier counts
        # deadlocked distributed runs at the end of training. EMA is updated
        # on every rank during train_epoch, so applying the shadow everywhere
        # is consistent.
        if self.ema:
            self.ema.apply_shadow(self.raw_model)
            self.save_checkpoint("final-ema")
            self.ema.restore(self.raw_model)

        self.save_checkpoint("final")

        if self.is_main:
            logger.info("Training complete!")
            if self.wandb:
                self.wandb.finish()


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def load_config(config_path: Optional[str] = None) -> TrainConfig:
    """Load config from YAML or use defaults.

    YAML keys override TrainConfig fields; a nested "data" mapping overrides
    DataConfig fields. Values are coerced to the dataclass field's declared
    type so YAML scientific notation (e.g. "1e-4", "2e4") lands as the
    expected float/int. Unknown keys are ignored with a warning instead of
    silently creating dead attributes.

    Args:
        config_path: Optional path to a YAML config file.

    Returns:
        A populated TrainConfig.
    """
    config = TrainConfig()

    if config_path:
        import yaml
        import dataclasses
        with open(config_path) as f:
            cfg_dict = yaml.safe_load(f)

        field_types = {f.name: f.type for f in dataclasses.fields(config)}
        data_field_types = {f.name: f.type for f in dataclasses.fields(config.data)}

        def _coerce(value, expected):
            # YAML parses numeric literals as int or float depending on
            # spelling; align the value with the dataclass annotation.
            if expected == float and not isinstance(value, float):
                return float(value)
            if expected == int and isinstance(value, float) and value.is_integer():
                return int(value)
            return value

        for key, value in cfg_dict.items():
            if key == 'data':
                for dk, dv in value.items():
                    # Guard data keys like top-level keys: a typo should not
                    # silently become an unused attribute.
                    if not hasattr(config.data, dk):
                        logger.warning(f"Unknown data config key ignored: {dk}")
                        continue
                    setattr(config.data, dk, _coerce(dv, data_field_types.get(dk)))
            elif hasattr(config, key):
                setattr(config, key, _coerce(value, field_types.get(key)))
            else:
                logger.warning(f"Unknown config key ignored: {key}")

    return config


def main():
    """Parse CLI arguments, build the trainer, and run training."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', type=str, default=None, help='Path to config YAML')
    cli_args = arg_parser.parse_args()

    trainer = Trainer(load_config(cli_args.config))
    trainer.train()


# Standard entry guard: runs under both `python train.py` and torchrun.
if __name__ == '__main__':
    main()
