#!/usr/bin/env python3
"""Production training script with optimized data loading + TDT + differential LR.

Data strategy (ported from parakeet pipeline):
  - Duration bucketing: groups similar-length clips → minimal padding waste
  - Language temperature rebalancing (T=0.3): upweights low-resource languages
  - Max-batch-duration packing: caps audio per micro-batch → stable VRAM
  - Compute-cost DDP sharding: balances load across GPUs
  - Grad accum alignment: prevents DDP sync bugs

Optimizations v2:
  - TDT loss (matching pretrained Parakeet encoder objective)
  - Differential LR: encoder=2e-5, decoder/joint/CTC=1e-3
  - Encoder frozen for first 5K steps, then unfrozen
  - CTC loss warmup: 0→0.3 over first 3K steps
  - Language embedding conditioning
  - OOM-safe VRAM ceiling design

Audio I/O: TarOffsetReader (os.pread, zero extraction from existing tars).
Model: NeMo EncDecHybridRNNTCTCBPEModel (FastConformer + TDT + aux CTC).

Usage:
  # Full production run (8×H200)
  python3 scripts/train_prod.py --config configs/train/stage1_prod_8xh200.yaml

  # Quick smoke test
  python3 scripts/train_prod.py --config configs/train/stage1_prod_8xh200.yaml \
    --max-steps 100 --devices 1 --log-every 10 --val-every 50
"""

import argparse
import os
import sys
import time
from pathlib import Path

import lightning.pytorch as pl
import nemo.collections.asr as nemo_asr
import numpy as np
import pyarrow.parquet as pq
import torch
from nemo.utils.exp_manager import exp_manager
from omegaconf import OmegaConf, open_dict


# ---------------------------------------------------------------------------
# Language embedding module
# ---------------------------------------------------------------------------
# Language-code → integer-ID mapping. These IDs index the LanguageEmbedding
# table, so changing or reordering existing entries invalidates trained
# checkpoints. Unknown codes fall back to "en" (see NeMoTarOffsetDataset).
LANG_TO_ID = {
    "hi": 0, "bn": 1, "ta": 2, "te": 3, "mr": 4, "gu": 5,
    "kn": 6, "ml": 7, "pa": 8, "or": 9, "as": 10, "en": 11,
}
NUM_LANGUAGES = len(LANG_TO_ID)


class LanguageEmbedding(torch.nn.Module):
    """Maps a batch of language IDs to an additive per-frame encoder bias."""

    def __init__(self, num_languages: int, embed_dim: int):
        super().__init__()
        self.embed = torch.nn.Embedding(num_languages, embed_dim)
        # Small-std init keeps the bias near zero at the start of training,
        # so pretrained encoder features are not disrupted.
        torch.nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)

    def forward(self, lang_ids: torch.Tensor) -> torch.Tensor:
        """Return a (B, embed_dim, 1) bias, broadcastable over (B, D, T).

        Args:
            lang_ids: (B,) integer tensor of language IDs.
        """
        embedded = self.embed(lang_ids)  # (B, embed_dim)
        # Trailing singleton dim broadcasts across the encoder's time axis.
        return embedded[:, :, None]


# ---------------------------------------------------------------------------
# Callbacks
# ---------------------------------------------------------------------------
class ThroughputCallback(pl.Callback):
    """Logs VRAM, throughput (samples/s, audio-s/s), and MFU every log_every steps.

    Accumulates per-micro-batch counters over a window; on rank 0, each time
    ``global_step`` crosses a ``log_every`` boundary, prints windowed rates,
    logs them via Lightning, and resets the window.
    """

    def __init__(self, log_every: int = 10, model_params: int = 0,
                 grad_accum: int = 1, world_size: int = 1):
        self.log_every = log_every
        self.model_params = model_params
        self.grad_accum = grad_accum
        self.world_size = world_size
        self._window_t0 = None  # wall-clock start of the current window
        self._window_samples = 0
        self._window_audio_s = 0.0
        self._window_micro_steps = 0
        # BUGFIX: with gradient accumulation, several micro-batches share the
        # same global_step, so `global_step % log_every == 0` would fire (and
        # reset the window) once per micro-batch at a boundary step. Track the
        # last step logged so each boundary is reported exactly once.
        self._last_logged_step = -1
        # H200 specs: bf16 peak ~990 TFLOPS
        self._gpu_peak_tflops = 989.0

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self._window_t0 is None:
            self._window_t0 = time.time()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # batch layout: (signals, sig_lens, ...) — see collate_nemo_batch.
        if batch is not None and len(batch) >= 2:
            sig_lens = batch[1]
            n_samples = sig_lens.shape[0]
            # sig_lens are raw sample counts at 16 kHz → seconds of audio.
            audio_s = sig_lens.float().sum().item() / 16000.0
            self._window_samples += n_samples
            self._window_audio_s += audio_s
        self._window_micro_steps += 1

        rank = int(os.environ.get("LOCAL_RANK", 0))
        global_step = trainer.global_step

        should_log = (
            rank == 0
            and global_step > 0
            and global_step % self.log_every == 0
            and global_step != self._last_logged_step  # once per boundary step
            and self._window_t0 is not None
        )
        if should_log:
            elapsed = time.time() - self._window_t0
            if elapsed > 0:
                samples_per_s = self._window_samples / elapsed
                audio_s_per_s = self._window_audio_s / elapsed
                rtf_per_gpu = audio_s_per_s / max(self.world_size, 1)
                opt_steps_per_s = (self._window_micro_steps / max(self.grad_accum, 1)) / elapsed

                vram_alloc_gb = torch.cuda.memory_allocated() / 1e9
                vram_reserved_gb = torch.cuda.memory_reserved() / 1e9
                vram_peak_gb = torch.cuda.max_memory_allocated() / 1e9

                # Rough MFU: 100 mel frames per audio second, 6 FLOPs/param/frame.
                mel_frames_per_s = audio_s_per_s * 100
                estimated_tflops = (6 * self.model_params * mel_frames_per_s) / 1e12
                mfu = estimated_tflops / (self._gpu_peak_tflops * self.world_size) * 100

                print(f"  [step {global_step:>6d}] "
                      f"samples/s={samples_per_s:.1f}  "
                      f"audio-s/s={audio_s_per_s:.0f}  "
                      f"RTF/GPU={rtf_per_gpu:.1f}  "
                      f"opt-steps/s={opt_steps_per_s:.3f}  "
                      f"VRAM={vram_alloc_gb:.1f}/{vram_reserved_gb:.1f}GB  "
                      f"peak={vram_peak_gb:.1f}GB  "
                      f"MFU≈{mfu:.1f}%")

                # Log custom metrics to wandb/tensorboard via Lightning
                if pl_module is not None:
                    custom_metrics = {
                        "throughput/samples_per_s": samples_per_s,
                        "throughput/audio_s_per_s": audio_s_per_s,
                        "throughput/rtf_per_gpu": rtf_per_gpu,
                        "throughput/opt_steps_per_s": opt_steps_per_s,
                        "throughput/mfu_pct": mfu,
                        "memory/vram_alloc_gb": vram_alloc_gb,
                        "memory/vram_reserved_gb": vram_reserved_gb,
                        "memory/vram_peak_gb": vram_peak_gb,
                    }
                    # Log LR per param group (names set by DifferentialLRCallback)
                    optimizer = trainer.optimizers[0] if trainer.optimizers else None
                    if optimizer:
                        for i, pg in enumerate(optimizer.param_groups):
                            name = pg.get('name', f'group_{i}')
                            custom_metrics[f"lr/{name}"] = pg['lr']
                    pl_module.log_dict(custom_metrics, prog_bar=False)

            self._last_logged_step = global_step
            self._window_t0 = time.time()
            self._window_samples = 0
            self._window_audio_s = 0.0
            self._window_micro_steps = 0


class EncoderFreezeCallback(pl.Callback):
    """Keeps the encoder frozen for the first N optimizer steps, then unfreezes it."""

    def __init__(self, freeze_steps: int = 5000):
        self.freeze_steps = freeze_steps
        self._frozen = False
        self._unfrozen = False

    def on_train_start(self, trainer, pl_module):
        if self.freeze_steps > 0:
            self._freeze_encoder(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self._unfrozen:
            return
        if self._frozen and trainer.global_step >= self.freeze_steps:
            self._unfreeze_encoder(pl_module)

    def _set_encoder_grad(self, model, requires_grad):
        """Toggle requires_grad on every `encoder.*` param; return the count touched."""
        count = 0
        for name, param in model.named_parameters():
            if not name.startswith("encoder."):
                continue
            param.requires_grad = requires_grad
            count += 1
        return count

    def _freeze_encoder(self, model):
        frozen = self._set_encoder_grad(model, False)
        self._frozen = True
        if int(os.environ.get("LOCAL_RANK", 0)) == 0:
            print(f"  [EncoderFreeze] Froze {frozen} encoder params for first {self.freeze_steps} steps")

    def _unfreeze_encoder(self, model):
        unfrozen = self._set_encoder_grad(model, True)
        self._frozen = False
        self._unfrozen = True
        if int(os.environ.get("LOCAL_RANK", 0)) == 0:
            print(f"  [EncoderFreeze] Unfroze {unfrozen} encoder params at step {model.trainer.global_step}")


class SamplerStateCallback(pl.Callback):
    """Persists sampler progress (consumed batches) alongside checkpoints.

    On resume, the sampler can skip already-trained batches to avoid
    reprocessing the same data.
    """
    STATE_FILE = "sampler_state.json"

    def __init__(self, sampler, state_dir: str = "/tmp"):
        self.sampler = sampler
        self.state_dir = state_dir

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Periodic save: every 1000 optimizer steps.
        step = trainer.global_step
        if step > 0 and step % 1000 == 0:
            self._save_state(step)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        self._save_state(trainer.global_step)

    def _save_state(self, global_step):
        import json
        payload = {
            "global_step": global_step,
            "epoch": self.sampler.epoch,
            "consumed_batches": getattr(self.sampler, "_consumed_batches", 0),
        }
        with open(os.path.join(self.state_dir, self.STATE_FILE), "w") as f:
            json.dump(payload, f)

    @staticmethod
    def load_state(state_dir: str = "/tmp") -> dict | None:
        import json
        path = os.path.join(state_dir, SamplerStateCallback.STATE_FILE)
        if not os.path.exists(path):
            return None
        with open(path) as f:
            return json.load(f)


class CTCWarmupCallback(pl.Callback):
    """Linearly ramps the CTC loss weight from 0 to `target_weight` over N steps.

    Early in training the randomly-initialized CTC head produces noisy
    gradients; the warmup prevents them from corrupting the pretrained encoder.
    """

    def __init__(self, target_weight: float = 0.3, warmup_steps: int = 3000):
        self.target_weight = target_weight
        self.warmup_steps = warmup_steps

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        step = trainer.global_step
        # Linear ramp until warmup_steps, then hold at the target.
        fraction = step / self.warmup_steps if step < self.warmup_steps else 1.0
        if hasattr(pl_module, 'ctc_loss_weight'):
            pl_module.ctc_loss_weight = self.target_weight * fraction


class DifferentialLRCallback(pl.Callback):
    """Sets differential learning rates for encoder vs decoder/joint/CTC param groups.

    After NeMo constructs the optimizer (which uses a single LR), this callback
    modifies the param groups to use different LRs.

    Encoder (pretrained): low LR to preserve features
    Decoder/Joint/CTC (random init): high LR for fast adaptation
    """

    def __init__(self, encoder_lr: float = 2e-5, head_lr: float = 1e-3):
        # encoder_lr: applied to params whose name starts with "encoder."
        # head_lr:    applied to everything else (decoder, joint, CTC, etc.)
        self.encoder_lr = encoder_lr
        self.head_lr = head_lr

    def on_train_start(self, trainer, pl_module):
        # Runs once, after the optimizer exists but before any step.
        optimizer = trainer.optimizers[0]

        # NeMo creates a single param group. We need to split it into
        # encoder (low LR) and head (high LR) groups.
        # Include ALL params regardless of requires_grad (freeze callback
        # will toggle requires_grad but params must be in optimizer).

        encoder_params = []
        head_params = []

        # Split strictly by name prefix — same predicate EncoderFreezeCallback uses.
        for name, param in pl_module.named_parameters():
            if name.startswith("encoder."):
                encoder_params.append(param)
            else:
                head_params.append(param)

        # Preserve optimizer settings from existing group
        # (shallow copy carries betas/weight_decay/etc.; 'params' and 'lr'
        # are overwritten per group below, so sharing is harmless).
        base_group = optimizer.param_groups[0].copy()

        # Replace the single group in place rather than rebuilding the
        # optimizer, so per-param state (momentum/variance) is kept.
        optimizer.param_groups.clear()

        encoder_group = base_group.copy()
        encoder_group['params'] = encoder_params
        encoder_group['lr'] = self.encoder_lr
        encoder_group['name'] = 'encoder'

        head_group = base_group.copy()
        head_group['params'] = head_params
        head_group['lr'] = self.head_lr
        head_group['name'] = 'decoder_joint_ctc'

        # Group order matters: base_lrs below is [encoder, head].
        optimizer.param_groups.append(encoder_group)
        optimizer.param_groups.append(head_group)

        # Fix LR scheduler: it was initialized with 1 group, now we have 2.
        # Update base_lrs and last_epoch tracking for the scheduler.
        # NOTE(review): assumes the scheduler exposes base_lrs/_last_lr like
        # torch.optim.lr_scheduler schedulers — confirm for NeMo's schedulers.
        if trainer.lr_scheduler_configs:
            for sched_config in trainer.lr_scheduler_configs:
                scheduler = sched_config.scheduler
                # Update base_lrs to match new param groups
                scheduler.base_lrs = [self.encoder_lr, self.head_lr]
                # Update _last_lr if it exists
                if hasattr(scheduler, '_last_lr'):
                    scheduler._last_lr = [self.encoder_lr, self.head_lr]

        rank = int(os.environ.get("LOCAL_RANK", 0))
        if rank == 0:
            enc_count = sum(p.numel() for p in encoder_params)
            head_count = sum(p.numel() for p in head_params)
            enc_trainable = sum(p.numel() for p in encoder_params if p.requires_grad)
            print(f"  [DiffLR] Encoder: {enc_count/1e6:.0f}M params @ lr={self.encoder_lr} ({enc_trainable/1e6:.0f}M trainable)")
            print(f"  [DiffLR] Heads:   {head_count/1e6:.0f}M params @ lr={self.head_lr}")


# ---------------------------------------------------------------------------
# OOM-safe VRAM budget calculator
# ---------------------------------------------------------------------------
def compute_vram_safe_batch_params(
    vocab_size: int = 32769,
    enc_hidden: int = 1024,
    joint_hidden: int = 640,
    num_durations: int = 5,
    max_audio_dur: float = 20.0,
    subsampling_factor: int = 8,
    sample_rate: int = 16000,
    window_stride: float = 0.01,
    gpu_memory_gb: float = 140.0,
    model_params_gb: float = 4.5,  # ~1.1B params in bf16
    safety_margin: float = 0.75,  # use at most 75% of available
    grad_accum: int = 4,
) -> dict:
    """Derive OOM-safe batching parameters from a simple VRAM budget model.

    The dominant consumer is the TDT joint tensor:
      joint_size = B × T_enc × T_dec × (vocab_size + 1 + num_durations)
    TDT's duration-skipping keeps T_dec far below plain RNNT's.

    Returns:
        dict with recommended max_batch_dur, max_batch_size, and the
        intermediate budget figures used to derive them.
    """
    # VRAM left for activations after weights + AdamW state are accounted for.
    adamw_state_gb = model_params_gb * 3  # AdamW: params + momentum + variance (bf16 + fp32)
    budget_gb = (gpu_memory_gb - model_params_gb - adamw_state_gb) * safety_margin

    # Worst-case encoder frame count for the longest allowed utterance.
    frames_mel = max_audio_dur / window_stride  # 2000 for 20s at 10ms stride
    frames_enc = frames_mel / subsampling_factor  # 250 after 8x subsampling

    # Conservative decoder-length ceiling for a long utterance (~50 tokens
    # typical for 20s with a 32K vocab; 80 leaves headroom).
    t_dec_cap = 80
    joint_out_dim = vocab_size + 1 + num_durations

    # Per-sample joint tensor in GB (bf16 = 2 bytes), doubled for backward.
    joint_gb = (frames_enc * t_dec_cap * joint_out_dim * 2) / 1e9
    joint_gb *= 2

    # Encoder activations per sample (~0.5 GB, 42-layer conformer on 20s audio).
    enc_act_gb = 0.5
    per_sample_gb = joint_gb + enc_act_gb

    # Largest sample count that fits the activation budget (at least 1).
    safe_batch_size = max(1, int(budget_gb / per_sample_gb))

    # Translate to a duration cap assuming ~8s average clips.
    typical_clip_s = 8.0
    safe_batch_dur = safe_batch_size * typical_clip_s

    return {
        "max_batch_size": min(safe_batch_size, 16),  # cap at 16 for safety
        "max_batch_dur": min(safe_batch_dur, 120.0),  # cap at 120s
        "activation_budget_gb": round(budget_gb, 1),
        "joint_per_sample_gb": round(joint_gb, 3),
        "total_per_sample_gb": round(per_sample_gb, 3),
        "gpu_memory_gb": gpu_memory_gb,
        "model_params_gb": round(model_params_gb, 1),
        "optimizer_gb": round(adamw_state_gb, 1),
    }


# ---------------------------------------------------------------------------
# Dataset and collate
# ---------------------------------------------------------------------------
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from lf_asr.data.production_sampler import ProductionBatchSampler
from lf_asr.data.tar_offset_reader import TarOffsetReader


class NeMoTarOffsetDataset(torch.utils.data.Dataset):
    """NeMo-compatible dataset backed by TarOffsetReader.

    Items are tuples (signal, signal_length, tokens, token_length, lang_id);
    the first four match NeMo's EncDecHybridRNNTCTCBPEModel expectations and
    the trailing lang_id supports language conditioning.
    """

    def __init__(self, manifest_parquet, tokenizer, sample_rate=16000,
                 max_duration=30.0, min_duration=0.5):
        df = pq.read_table(manifest_parquet).to_pandas()

        # Keep rows with a non-blank transcript and an in-range duration.
        has_text = df["transcript"].str.strip().str.len() > 0
        in_range = (df["duration_s"] >= min_duration) & (df["duration_s"] <= max_duration)

        n_blank = (~has_text).sum()
        if n_blank > 0:
            print(f"  WARNING: Dropped {n_blank:,} blank transcripts at load time")
        n_dur = ((df["duration_s"] < min_duration) | (df["duration_s"] > max_duration)).sum()
        if n_dur > 0:
            print(f"  Filtered {n_dur:,} samples outside [{min_duration}, {max_duration}]s")

        self.df = df[in_range & has_text].reset_index(drop=True)
        self.reader = TarOffsetReader(max_fd_cache=512)
        self.tokenizer = tokenizer
        self.sample_rate = sample_rate

        # Materialize columns as plain arrays so __getitem__ skips pandas overhead.
        self.tar_paths = self.df["tar_path"].values
        self.offsets = self.df["tar_offset_data"].values.astype(np.int64)
        self.nbytes = self.df["tar_nbytes"].values.astype(np.int64)
        self.transcripts = self.df["transcript"].values
        self.durations = self.df["duration_s"].values.astype(np.float32)
        self.languages = self.df["language"].values

        # Resolve language codes to integer IDs once, up front.
        self.lang_ids = np.array(
            [LANG_TO_ID.get(str(code).strip().lower(), LANG_TO_ID["en"]) for code in self.languages],
            dtype=np.int64,
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        tar_path = str(self.tar_paths[idx])
        offset = int(self.offsets[idx])
        nbytes = int(self.nbytes[idx])

        # Bad audio is reported and mapped to None; the collate fn drops it.
        try:
            waveform, sr = self.reader.read(tar_path, offset, nbytes)
        except Exception as err:
            print(f"  WARNING: Bad audio idx={idx} tar={tar_path} offset={offset}: {err}")
            return None
        if sr != self.sample_rate:
            print(f"  WARNING: Sample rate mismatch idx={idx}: {sr} != {self.sample_rate}")
            return None

        text = str(self.transcripts[idx]).strip() or "<blank>"
        token_ids = self.tokenizer.text_to_ids(text)

        return (
            torch.from_numpy(waveform).float(),
            torch.tensor(len(waveform), dtype=torch.long),
            torch.tensor(token_ids, dtype=torch.long),
            torch.tensor(len(token_ids), dtype=torch.long),
            torch.tensor(int(self.lang_ids[idx]), dtype=torch.long),
        )


def collate_nemo_batch(batch):
    """Pad and stack dataset samples into a NeMo-style batch.

    Samples that failed to load arrive as None and are dropped; if every
    sample was bad, a minimal 1-row placeholder batch is returned.

    Returns (signals, sig_lens, targets, tgt_lens, lang_ids).
    """
    samples = [item for item in batch if item is not None]
    if not samples:
        return (
            torch.zeros(1, 1),
            torch.tensor([1], dtype=torch.long),
            torch.zeros(1, 1, dtype=torch.long),
            torch.tensor([0], dtype=torch.long),
            torch.tensor([0], dtype=torch.long),
        )

    signals, sig_lens, targets, tgt_lens, lang_ids = zip(*samples)

    n = len(samples)
    sig_pad = torch.zeros(n, max(sig.shape[0] for sig in signals))
    tgt_pad = torch.zeros(n, max(tgt.shape[0] for tgt in targets), dtype=torch.long)
    for row, (sig, tgt) in enumerate(zip(signals, targets)):
        sig_pad[row, :sig.shape[0]] = sig
        tgt_pad[row, :tgt.shape[0]] = tgt

    return (
        sig_pad,
        torch.stack(list(sig_lens)),
        tgt_pad,
        torch.stack(list(tgt_lens)),
        torch.stack(list(lang_ids)),
    )


def get_ddp_info() -> tuple[int, int]:
    """Get (rank, world_size) from the DDP environment.

    Prefers an initialized torch.distributed process group. In the env-var
    fallback, RANK (the global rank set by torchrun) is checked before
    LOCAL_RANK (the per-node rank): on multi-node runs LOCAL_RANK restarts
    at 0 on every node, so preferring it would mislabel ranks globally.
    On single-node runs the two agree, so behavior there is unchanged.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(), torch.distributed.get_world_size()
    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return rank, world_size


# ---------------------------------------------------------------------------
# Language-conditioned forward hook
# ---------------------------------------------------------------------------
def install_language_conditioning(model, lang_embed: LanguageEmbedding):
    """Install a forward hook on the encoder to add language embeddings.

    After the encoder produces its output, we add the language embedding
    as a per-frame bias. This gives the encoder a language signal without
    modifying the encoder architecture (preserving pretrained weights).

    NOTE(review): the wrapper only applies the bias when `lang_ids` is passed
    explicitly; a caller invoking the encoder with just (audio_signal, length)
    gets unconditioned output. Confirm which call path is active — see
    `patch_encoder_for_lang`, which conditions via model state instead.
    """
    # Store reference on model so it moves to GPU with model
    model._lang_embed = lang_embed

    original_forward = model.encoder.forward

    def encoder_forward_with_lang(audio_signal, length, lang_ids=None):
        encoded, encoded_len = original_forward(audio_signal=audio_signal, length=length)
        if lang_ids is not None and hasattr(model, '_lang_embed'):
            lang_bias = model._lang_embed(lang_ids)  # (B, D, 1) — broadcasts over time
            encoded = encoded + lang_bias
        return encoded, encoded_len

    model.encoder.forward = encoder_forward_with_lang
    return model


# ---------------------------------------------------------------------------
# Custom training step with language conditioning
# ---------------------------------------------------------------------------
def patch_training_step(model):
    """Monkey-patch model.training_step so language IDs ride along.

    The dataloader yields (signal, signal_len, tokens, token_len, lang_ids),
    but NeMo's training_step only understands the first four elements. The
    patched step stashes lang_ids on the model (for the encoder hook to read)
    and forwards the standard 4-tuple.
    """
    nemo_step = model.training_step

    def step_with_lang(batch, batch_idx):
        if not isinstance(batch, (list, tuple)) or len(batch) < 5:
            # Already a plain NeMo batch — pass through untouched.
            return nemo_step(batch, batch_idx)
        signal, signal_len, tokens, token_len, lang_ids = batch[:5]
        # The encoder hook picks this attribute up during the forward pass.
        model._current_lang_ids = lang_ids.to(signal.device)
        return nemo_step((signal, signal_len, tokens, token_len), batch_idx)

    model.training_step = step_with_lang


def patch_encoder_for_lang(model):
    """Wrap encoder.forward to add the language-embedding bias stored on the model."""
    base_forward = model.encoder.forward

    def forward_with_lang(audio_signal, length):
        encoded, encoded_len = base_forward(audio_signal=audio_signal, length=length)
        lang_ids = getattr(model, '_current_lang_ids', None)
        if lang_ids is not None and hasattr(model, '_lang_embed'):
            # Guard against batch-size mismatch (e.g. during validation).
            if lang_ids.shape[0] == encoded.shape[0]:
                encoded = encoded + model._lang_embed(lang_ids)  # bias is (B, D, 1)
            # Drop the IDs so a later (e.g. validation) pass never reuses them.
            model._current_lang_ids = None
        return encoded, encoded_len

    model.encoder.forward = forward_with_lang


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _tokenize_chunk_worker(texts, tokenizer_model_path):
    """Count token lengths for a chunk of texts in a worker process.

    Defined at module level so multiprocessing can pickle it.
    """
    import sentencepiece as spm
    processor = spm.SentencePieceProcessor()
    processor.Load(tokenizer_model_path)
    lengths = []
    for text in texts:
        normalized = str(text).strip() or "<blank>"
        lengths.append(len(processor.EncodeAsIds(normalized)))
    return lengths


def main() -> None:
    """CLI entry point: parse args, configure, build, and run the training job.

    High-level flow:
      1. Parse CLI args and overlay them onto the YAML config.
      2. Compute an OOM-safe VRAM budget and print a run summary (rank 0 only).
      3. Build the Lightning trainer + NeMo hybrid RNNT/CTC model.
      4. Optionally restore weights (full checkpoint or encoder-only).
      5. Attach language-embedding conditioning and the training callbacks.
      6. Build the tar-offset dataset, token-length cache, and batch sampler.
      7. Inject the custom DataLoader into the model and run trainer.fit().
    """
    parser = argparse.ArgumentParser(
        description="Production NeMo ASR training with TDT + differential LR"
    )
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--max-steps", type=int, default=None)
    parser.add_argument("--devices", type=int, default=None)
    parser.add_argument("--train-parquet", type=str,
                        default="artifacts/phase3/production_train_final.parquet")
    parser.add_argument("--val-manifest", type=str,
                        default="data/manifests/stage1_prod_val_v2.jsonl")
    parser.add_argument("--resume-from-checkpoint", type=str, default=None)
    parser.add_argument("--resume-weights-only", action="store_true", default=False,
                        help="Resume model weights from checkpoint but reset optimizer (for config changes)")
    parser.add_argument("--pretrained-encoder", type=str, default=None,
                        help="Path to pretrained encoder weights (.pt file)")
    parser.add_argument("--smoke", action="store_true", default=False)
    parser.add_argument("--val-every", type=int, default=None)
    parser.add_argument("--log-every", type=int, default=None)

    # Data strategy args
    parser.add_argument("--max-batch-dur", type=float, default=120.0,
                        help="Max audio seconds per micro-batch (default: 120s)")
    parser.add_argument("--max-batch-size", type=int, default=16,
                        help="Global hard cap on samples per micro-batch")
    parser.add_argument("--max-tokens-in-batch", type=int, default=400,
                        help="Max total tokens per batch (controls joint tensor T_dec)")
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--grad-accum", type=int, default=None)
    parser.add_argument("--num-workers", type=int, default=None)

    # v2 optimization args
    parser.add_argument("--encoder-lr", type=float, default=2e-5,
                        help="Learning rate for pretrained encoder")
    parser.add_argument("--head-lr", type=float, default=1e-3,
                        help="Learning rate for decoder/joint/CTC heads")
    parser.add_argument("--freeze-encoder-steps", type=int, default=5000,
                        help="Freeze encoder for this many optimizer steps (0=no freeze)")
    parser.add_argument("--ctc-warmup-steps", type=int, default=3000,
                        help="Steps to warm up CTC loss from 0 to target weight")
    parser.add_argument("--no-lang-embed", action="store_true", default=False,
                        help="Disable language embedding conditioning")
    args = parser.parse_args()

    cfg = OmegaConf.load(args.config)

    # open_dict lifts OmegaConf struct mode so CLI overrides may add keys
    # that are absent from the YAML schema.
    with open_dict(cfg):
        if args.max_steps is not None:
            cfg.trainer.max_steps = args.max_steps
        if args.devices is not None:
            cfg.trainer.devices = args.devices
            # find_unused_parameters: presumably needed because some params
            # (frozen encoder / patched paths) skip the backward pass — TODO confirm.
            cfg.trainer.strategy = "ddp_find_unused_parameters_true" if args.devices > 1 else "auto"
        if args.grad_accum is not None:
            cfg.trainer.accumulate_grad_batches = args.grad_accum
        if args.val_every is not None:
            cfg.trainer.val_check_interval = args.val_every
            # Keep checkpointing cadence aligned with validation cadence.
            cfg.exp_manager.checkpoint_callback_params.every_n_train_steps = args.val_every
        if args.log_every is not None:
            cfg.trainer.log_every_n_steps = args.log_every
        if args.smoke and args.val_every is None:
            cfg.trainer.val_check_interval = 1.0
        if args.resume_from_checkpoint is not None:
            cfg.exp_manager.resume_from_checkpoint = args.resume_from_checkpoint
            cfg.exp_manager.resume_if_exists = True

        # Force the non-tarred manifest dataset path and strip tarred-only keys.
        # NOTE(review): train_ds.manifest_filepath is set to the VAL manifest —
        # it appears to be a placeholder, since the real train data is injected
        # later via model._train_dl; confirm NeMo never reads it for training.
        cfg.model.train_ds.is_tarred = False
        for key in ("tarred_audio_filepaths", "shuffle_n", "input_cfg"):
            if key in cfg.model.train_ds:
                del cfg.model.train_ds[key]
        cfg.model.train_ds.manifest_filepath = args.val_manifest
        cfg.model.validation_ds.manifest_filepath = args.val_manifest

    grad_accum = cfg.trainer.get("accumulate_grad_batches", 1)
    devices = cfg.trainer.get("devices", 1)
    max_dur = cfg.model.train_ds.get("max_duration", 20.0)
    min_dur = cfg.model.train_ds.get("min_duration", 0.5)
    num_workers = args.num_workers or cfg.model.train_ds.get("num_workers", 8)

    # NOTE(review): LOCAL_RANK is used as the global rank; correct only for
    # single-node runs — verify before multi-node use.
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = devices if isinstance(devices, int) else 1

    # -----------------------------------------------------------------------
    # OOM-safe VRAM budget
    # -----------------------------------------------------------------------
    # Helper defined elsewhere in this file; 140 GB is presumably the usable
    # H200 capacity (of 141 GB) — confirm against the budget helper.
    vram_budget = compute_vram_safe_batch_params(
        gpu_memory_gb=140.0,
        max_audio_dur=max_dur,
        grad_accum=grad_accum,
    )
    max_batch_dur = args.max_batch_dur
    max_batch_size = args.max_batch_size

    # Run-summary banner (rank 0 only, to avoid interleaved DDP output).
    if rank == 0:
        print("=" * 60)
        print("  Maya ASR — Production Training v2 (TDT + DiffLR)")
        print("=" * 60)
        print()
        print(f"Config:           {args.config}")
        print(f"Max steps:        {cfg.trainer.max_steps}")
        print(f"Devices:          {devices}")
        print(f"Grad accum:       {grad_accum}")
        print(f"Precision:        {cfg.trainer.get('precision', 'bf16-mixed')}")
        print(f"Train parquet:    {args.train_parquet}")
        print(f"Val manifest:     {args.val_manifest}")
        print()
        print("v2 Optimizations:")
        print(f"  Loss:           TDT (durations=[0,1,2,3,4])")
        print(f"  Encoder LR:     {args.encoder_lr}")
        print(f"  Head LR:        {args.head_lr}")
        print(f"  Encoder freeze: {args.freeze_encoder_steps} steps")
        print(f"  CTC warmup:     0→target over {args.ctc_warmup_steps} steps")
        print(f"  Lang embed:     {'yes' if not args.no_lang_embed else 'no'}")
        print()
        print("VRAM Budget (OOM-safe):")
        print(f"  GPU memory:     {vram_budget['gpu_memory_gb']:.0f} GB")
        print(f"  Model (bf16):   {vram_budget['model_params_gb']:.1f} GB")
        print(f"  Optimizer:      {vram_budget['optimizer_gb']:.1f} GB")
        print(f"  Activation:     {vram_budget['activation_budget_gb']:.1f} GB (75% safety)")
        print(f"  Joint/sample:   {vram_budget['joint_per_sample_gb']:.3f} GB")
        print(f"  Total/sample:   {vram_budget['total_per_sample_gb']:.3f} GB")
        print(f"  → max_batch_size: {max_batch_size}")
        print(f"  → max_batch_dur:  {max_batch_dur:.0f}s")
        print()
        print("Data strategy:")
        print(f"  Max batch dur:  {max_batch_dur}s")
        print(f"  Max batch size: {max_batch_size}")
        print(f"  Temperature:    {args.temperature}")
        print(f"  Duration range: [{min_dur}, {max_dur}]s")
        print(f"  Num workers:    {num_workers}")
        print()

    # -----------------------------------------------------------------------
    # Build trainer and model
    # -----------------------------------------------------------------------
    # use_distributed_sampler=False: sharding is handled by the custom
    # ProductionBatchSampler below, so Lightning must not wrap the sampler.
    trainer = pl.Trainer(**cfg.trainer, logger=False, use_distributed_sampler=False)
    exp_manager(trainer, cfg.get("exp_manager", None))

    model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel(cfg=cfg.model, trainer=trainer)
    params = sum(p.numel() for p in model.parameters())
    if rank == 0:
        print(f"Model parameters: {params:,}")

    # Resume weights from checkpoint (without optimizer state)
    if args.resume_weights_only and args.resume_from_checkpoint:
        ckpt_path = args.resume_from_checkpoint
        if rank == 0:
            print(f"\nResuming weights only from: {ckpt_path}")
        # weights_only=False is required to unpickle the full Lightning
        # checkpoint — safe only because the checkpoint is self-produced.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # strict=False tolerates keys added after the checkpoint was written
        # (e.g. the language-embedding module).
        model.load_state_dict(ckpt["state_dict"], strict=False)
        resume_step = ckpt.get("global_step", 0)
        if rank == 0:
            print(f"  Loaded model weights from step {resume_step}")
            print(f"  Optimizer state NOT restored (fresh optimizer with DiffLR)")
        # Clear resume flag so Lightning doesn't also try to restore
        args.resume_from_checkpoint = None
        with open_dict(cfg):
            cfg.exp_manager.resume_from_checkpoint = None
            cfg.exp_manager.resume_if_exists = False
        del ckpt
        # NOTE(review): this print is not rank-guarded — every rank emits a
        # blank line; harmless but inconsistent with the guards above.
        print()

    # Load pretrained encoder weights
    if args.pretrained_encoder is not None and not args.resume_weights_only:
        if rank == 0:
            print(f"\nLoading pretrained encoder from: {args.pretrained_encoder}")
        pretrained_state = torch.load(args.pretrained_encoder, map_location="cpu", weights_only=True)
        model_state = model.state_dict()

        # Copy only tensors whose name AND shape match; everything else
        # (e.g. decoder/joint for a different vocab) keeps its random init.
        loaded, skipped = 0, 0
        for key, value in pretrained_state.items():
            if key in model_state and model_state[key].shape == value.shape:
                model_state[key] = value
                loaded += 1
            else:
                skipped += 1
                if rank == 0 and skipped <= 10:
                    expected = model_state[key].shape if key in model_state else "NOT FOUND"
                    print(f"  Skip: {key} pretrained={value.shape} model={expected}")

        model.load_state_dict(model_state)
        enc_params = sum(pretrained_state[k].numel() for k in pretrained_state if k in model_state and model_state[k].shape == pretrained_state[k].shape)
        if rank == 0:
            print(f"  Loaded {loaded} tensors ({enc_params/1e6:.0f}M params), skipped {skipped}")
            print(f"  Decoder + joint + CTC head: randomly initialized for {model.tokenizer.vocab_size} vocab")
        # NOTE(review): not rank-guarded (blank line printed on every rank).
        print()

    # -----------------------------------------------------------------------
    # Language conditioning
    # -----------------------------------------------------------------------
    if not args.no_lang_embed:
        lang_embed = LanguageEmbedding(NUM_LANGUAGES, cfg.model.model_defaults.enc_hidden)
        # Plain attribute for the patched encoder forward to find...
        model._lang_embed = lang_embed
        # Register as a proper submodule so it moves to GPU and is saved
        model.register_module('_lang_embed_module', lang_embed)
        patch_encoder_for_lang(model)
        patch_training_step(model)
        if rank == 0:
            lang_params = sum(p.numel() for p in lang_embed.parameters())
            print(f"  [LangEmbed] {NUM_LANGUAGES} languages → {cfg.model.model_defaults.enc_hidden}D ({lang_params:,} params)")

    # -----------------------------------------------------------------------
    # Callbacks
    # -----------------------------------------------------------------------
    log_interval = args.log_every or cfg.trainer.get("log_every_n_steps", 100)
    throughput_cb = ThroughputCallback(
        log_every=log_interval,
        model_params=params,
        grad_accum=grad_accum,
        world_size=world_size,
    )
    trainer.callbacks.append(throughput_cb)

    freeze_cb = EncoderFreezeCallback(freeze_steps=args.freeze_encoder_steps)
    trainer.callbacks.append(freeze_cb)

    ctc_warmup_cb = CTCWarmupCallback(
        target_weight=cfg.model.aux_ctc.get("ctc_loss_weight", 0.3),
        warmup_steps=args.ctc_warmup_steps,
    )
    trainer.callbacks.append(ctc_warmup_cb)

    diff_lr_cb = DifferentialLRCallback(
        encoder_lr=args.encoder_lr,
        head_lr=args.head_lr,
    )
    trainer.callbacks.append(diff_lr_cb)

    # -----------------------------------------------------------------------
    # Load train dataset
    # -----------------------------------------------------------------------
    if rank == 0:
        print(f"Loading train parquet: {args.train_parquet}")
    train_dataset = NeMoTarOffsetDataset(
        manifest_parquet=args.train_parquet,
        tokenizer=model.tokenizer,
        sample_rate=cfg.model.train_ds.get("sample_rate", 16000),
        max_duration=max_dur,
        min_duration=min_dur,
    )
    if rank == 0:
        print(f"Train samples: {len(train_dataset):,}")
        print(f"Train audio:   {train_dataset.durations.sum() / 3600:.0f} hours")
        # Language distribution
        unique_langs, lang_counts = np.unique(train_dataset.languages, return_counts=True)
        print(f"Languages:     {len(unique_langs)} ({', '.join(sorted(unique_langs))})")
        print()

    # Create production batch sampler
    # Pre-compute token lengths for VRAM-safe batching
    # Token length controls T_dec in the joint tensor — the KEY factor for OOM
    # Cached to .npy file so subsequent runs skip tokenization entirely
    cache_path = Path(args.train_parquet).with_suffix(".token_lengths.npy")
    if cache_path.exists():
        if rank == 0:
            print(f"Loading cached token lengths from {cache_path}")
        token_lengths = np.load(cache_path)
        # A stale cache (dataset changed size) is discarded and recomputed.
        # NOTE(review): unlink runs on every rank — possible race in DDP if
        # multiple ranks hit a mismatched cache simultaneously; verify.
        if len(token_lengths) != len(train_dataset):
            if rank == 0:
                print(f"  Cache size mismatch ({len(token_lengths)} vs {len(train_dataset)}), recomputing...")
            cache_path.unlink()
            token_lengths = None
        else:
            if rank == 0:
                print(f"  Loaded {len(token_lengths):,} cached token lengths")

    # Re-checked (not `else`) because the mismatch branch above deletes the file.
    if not cache_path.exists():
        if rank == 0:
            print("Pre-computing token lengths (parallel, 32 workers)...")
        import multiprocessing as mp

        tok_dir = cfg.model.tokenizer.dir
        tok_model = str(Path(tok_dir) / "tokenizer.model")

        # Split transcripts into one contiguous chunk per worker; order is
        # preserved by starmap, so the flattened result aligns with the dataset.
        transcripts = list(train_dataset.transcripts)
        n_workers = min(32, mp.cpu_count())
        chunk_size = (len(transcripts) + n_workers - 1) // n_workers
        chunks = [transcripts[i:i + chunk_size] for i in range(0, len(transcripts), chunk_size)]

        if rank == 0:
            t_start = time.time()

        with mp.Pool(n_workers) as pool:
            results = pool.starmap(
                _tokenize_chunk_worker,
                [(chunk, tok_model) for chunk in chunks],
            )

        token_lengths = np.array([l for chunk_result in results for l in chunk_result], dtype=np.int32)

        if rank == 0:
            elapsed = time.time() - t_start
            print(f"  Tokenized {len(token_lengths):,} samples in {elapsed:.1f}s "
                  f"({len(token_lengths)/elapsed:.0f} samples/s, {n_workers} workers)")
            # Cache for next run
            np.save(cache_path, token_lengths)
            print(f"  Cached to {cache_path}")

    if rank == 0:
        print(f"  Token lengths: mean={token_lengths.mean():.1f}, "
              f"max={token_lengths.max()}, p99={np.percentile(token_lengths, 99):.0f}")

    sampler = ProductionBatchSampler(
        durations=train_dataset.durations,
        languages=train_dataset.languages,
        token_lengths=token_lengths,
        bucket_boundaries=[4.0, 8.0, 12.0, 16.0],
        max_batch_duration=max_batch_dur,
        max_batch_size=max_batch_size,
        max_tokens_in_batch=args.max_tokens_in_batch,
        temperature=args.temperature,
        rank=rank,
        world_size=world_size,
        grad_accum=grad_accum,
        shuffle=True,
        seed=42,
        drop_last=True,
    )

    # Resume sampler state: skip already-trained batches
    if args.resume_from_checkpoint:
        sampler_state = SamplerStateCallback.load_state()
        if sampler_state and sampler_state.get("consumed_batches", 0) > 0:
            skip = sampler_state["consumed_batches"]
            sampler.set_resume_state(consumed_batches=skip)
            if rank == 0:
                print(f"  [Resume] Will skip {skip:,} already-consumed batches")

    # Add sampler state callback
    sampler_state_cb = SamplerStateCallback(sampler=sampler)
    trainer.callbacks.append(sampler_state_cb)

    if rank == 0:
        lang_dist = sampler.get_language_distribution()
        print("Language rebalancing (T={:.1f}):".format(args.temperature))
        print(f"  {'Lang':<6} {'Natural%':>9} {'Rebal%':>8} {'Boost':>7}")
        for lang, info in sorted(lang_dist.items(), key=lambda x: -x[1]["count"]):
            print(f"  {lang:<6} {info['natural_pct']:>8.1f}% {info['rebalanced_pct']:>7.1f}% {info['boost_factor']:>6.1f}x")
        print()

        # Dry-run sampler for stats (only rank 0)
        # Consumes one full epoch of batches to populate _last_stats, then
        # resets the epoch counter so training starts fresh.
        _ = list(sampler)
        sampler.print_stats(prefix="  ")
        sampler.epoch = 0  # reset

        eff_audio_per_step = max_batch_dur * world_size * grad_accum
        total_audio_h = train_dataset.durations.sum() / 3600
        steps_per_epoch = sampler._last_stats.get("optimizer_steps", 0)
        print()
        print(f"  Effective audio/optimizer step: ~{eff_audio_per_step:.0f}s ({eff_audio_per_step/60:.1f} min)")
        print(f"  Steps per epoch (rank {rank}): ~{steps_per_epoch:,}")
        print(f"  Total audio: {total_audio_h:.0f}h")
        print()

    # -----------------------------------------------------------------------
    # Build DataLoader
    # -----------------------------------------------------------------------
    train_dl = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_nemo_batch,
        pin_memory=True,
        persistent_workers=num_workers > 0,
        prefetch_factor=4 if num_workers > 0 else None,
    )
    # NOTE(review): assigning the private NeMo attribute _train_dl bypasses
    # setup_training_data(); confirm this stays supported across NeMo versions.
    model._train_dl = train_dl
    if rank == 0:
        print(f"DataLoader ready: {len(train_dl):,} micro-batches (rank 0)")
        print()
        print("=" * 60)
        print("  Starting training")
        print("=" * 60)
        print()

    t0 = time.time()
    trainer.fit(model)
    elapsed = time.time() - t0

    # Post-run throughput / VRAM summary and ETA projection (rank 0 only).
    if rank == 0:
        steps_done = trainer.global_step
        avg_step = elapsed / max(steps_done, 1)

        vram_peak = torch.cuda.max_memory_allocated() / 1e9
        vram_reserved = torch.cuda.memory_reserved() / 1e9

        print()
        print("=" * 60)
        print("  Training complete!")
        print("=" * 60)
        print(f"  Steps: {steps_done}")
        print(f"  Elapsed: {elapsed:.1f}s ({elapsed/3600:.1f}h)")
        print(f"  Avg sec/step: {avg_step:.2f}")
        print(f"  Opt steps/s: {1.0/avg_step:.3f}")
        print(f"  Peak VRAM (rank 0): {vram_peak:.1f} GB (reserved: {vram_reserved:.1f} GB)")
        if steps_done > 0:
            remaining_steps = cfg.trainer.max_steps - steps_done
            eta_h = remaining_steps * avg_step / 3600
            print(f"  ETA for remaining {remaining_steps:,} steps: {eta_h:.1f}h ({eta_h/24:.1f} days)")
            print()
            print(f"  --- Throughput Summary ---")
            print(f"  Config: max_batch_dur={max_batch_dur}s, grad_accum={grad_accum}, devices={world_size}")
            eff_audio = max_batch_dur * world_size * grad_accum
            print(f"  Effective audio/opt step: ~{eff_audio:.0f}s ({eff_audio/60:.1f} min)")
            total_audio_h = train_dataset.durations.sum() / 3600
            print(f"  Total training audio: {total_audio_h:.0f}h")
            epochs_for_200k = (200000 * avg_step) / 3600 / 24
            print(f"  Full 200K steps would take: ~{epochs_for_200k:.1f} days")


if __name__ == "__main__":
    # Run training only when executed as a script; the guard is also required
    # because mp.Pool workers re-import this module.
    main()