# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import json
import math
import os
import random
import re
import shutil
import time
from collections import OrderedDict

# Use file_system sharing to avoid /dev/shm exhaustion with many workers
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pyarrow.parquet as pq
import soundfile as sf
import torch
import torchaudio
from qwen_asr import Qwen3ASRModel
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import (GenerationConfig, Trainer, TrainerCallback,
                          TrainingArguments)


# Language code -> human-readable name used as the language tag in training
# targets ("language <Name><asr_text>..."); format_target_text falls back to
# "None" for codes missing from this table.
LANG_CODE_TO_NAME = {
    "en": "English",
    "zh": "Chinese",
    "yue": "Cantonese",
    "ar": "Arabic",
    "de": "German",
    "fr": "French",
    "es": "Spanish",
    "pt": "Portuguese",
    "id": "Indonesian",
    "it": "Italian",
    "ko": "Korean",
    "ru": "Russian",
    "th": "Thai",
    "vi": "Vietnamese",
    "ja": "Japanese",
    "tr": "Turkish",
    "hi": "Hindi",
    "ms": "Malay",
    "nl": "Dutch",
    "sv": "Swedish",
    "da": "Danish",
    "fi": "Finnish",
    "pl": "Polish",
    "cs": "Czech",
    "fil": "Filipino",
    "fa": "Persian",
    "el": "Greek",
    "ro": "Romanian",
    "hu": "Hungarian",
    "mk": "Macedonian",
    # Indic languages
    "as": "Assamese",
    "bn": "Bengali",
    "gu": "Gujarati",
    "kn": "Kannada",
    "ml": "Malayalam",
    "mr": "Marathi",
    "or": "Odia",
    "pa": "Punjabi",
    "ta": "Tamil",
    "te": "Telugu",
}

# Matches checkpoint directory names such as "checkpoint-500"; group 1 is the
# step number (used by find_latest_checkpoint to pick the highest step).
_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")

# Module-level batch stats for ProfilingCallback / crash breadcrumbs.
# NOTE(review): nothing in this part of the file reads or writes this dict;
# presumably the collator / callbacks defined later update it — confirm.
_BATCH_STATS = {
    "batch_size": 0,
    "audio_secs": 0.0,
    "seq_len": 0,
    "attn_tokens_sum": 0,
    "audio_tokens_max": 0,
    "label_tokens_sum": 0,
    "skip_reason": "",
    "debug_json": "",
}

# Non-speech markup tags that leak into native transcripts, e.g. "[laugh]",
# "[music]", "<breath>", plus one Bengali-script bracketed tag. Matched
# case-insensitively and removed by clean_transcript.
_MARKUP_RE = re.compile(r'\[(?:laugh|cough|singing|breath|noise|cry|sigh|music|applause|beep|হাসি)\]'
                         r'|<(?:laugh|breath|singing|articulated|reduced|inaudible)>', re.IGNORECASE)

# First word of a character's Unicode name (its script, e.g. "DEVANAGARI" from
# "DEVANAGARI LETTER KA") -> language code assumed when that script dominates a
# transcript. Scripts shared by several languages map to a single one
# (Devanagari -> "hi", never "mr"; Bengali -> "bn", never "as").
_SCRIPT_TO_LANG = {
    'DEVANAGARI': 'hi',
    'BENGALI': 'bn',
    'GUJARATI': 'gu',
    'GURMUKHI': 'pa',
    'KANNADA': 'kn',
    'MALAYALAM': 'ml',
    'ORIYA': 'or',
    'TAMIL': 'ta',
    'TELUGU': 'te',
}

# Scripts considered valid for each declared language. detect_script_language
# never relabels a language that is absent from this table.
_LANG_VALID_SCRIPTS = {
    'hi': {'DEVANAGARI'},
    'mr': {'DEVANAGARI'},
    'bn': {'BENGALI'},
    'as': {'BENGALI'},
    'gu': {'GUJARATI'},
    'pa': {'GURMUKHI'},
    'kn': {'KANNADA'},
    'ml': {'MALAYALAM'},
    'or': {'ORIYA'},
    'ta': {'TAMIL'},
    'te': {'TELUGU'},
    'en': {'LATIN'},
}


def clean_transcript(text: str) -> str:
    """Drop non-speech markup tags (``_MARKUP_RE``) and trim the result.

    Only leading/trailing whitespace is trimmed; spaces that surrounded a
    removed tag in the middle of the text are left in place.
    """
    without_tags = _MARKUP_RE.sub('', text)
    return without_tags.strip()


def detect_script_language(text: str, declared_lang: str) -> str:
    """Return a language code consistent with the transcript's dominant script.

    Alphabetic characters are bucketed by the first word of their Unicode name
    (Latin plus the Indic scripts in ``_SCRIPT_TO_LANG``). If ``declared_lang``
    has known valid scripts and the most frequent script is not one of them, the
    language mapped to that script is returned (or ``declared_lang`` if the
    script has no mapping, e.g. Latin). Otherwise ``declared_lang`` is returned.
    """
    import unicodedata

    counts: Dict[str, int] = {}
    for char in filter(str.isalpha, text):
        char_name = unicodedata.name(char, '')
        if not char_name:
            continue
        prefix = char_name.split()[0]
        if prefix == 'LATIN' or prefix in _SCRIPT_TO_LANG:
            counts[prefix] = counts.get(prefix, 0) + 1

    if not counts:
        return declared_lang

    # Ties resolve to the script seen first in the text.
    top_script = max(counts, key=counts.get)
    allowed = _LANG_VALID_SCRIPTS.get(declared_lang, set())
    if not allowed or top_script in allowed:
        return declared_lang
    return _SCRIPT_TO_LANG.get(top_script, declared_lang)


def patch_outer_forward(model):
    """Make the wrapper model's ``forward`` delegate to ``model.thinker.forward``.

    The replacement is installed on ``type(model)``, so it affects every
    instance of that class. A ``_forward_patched`` class attribute makes
    repeated calls no-ops.

    Raises:
        RuntimeError: if ``model`` has no ``thinker`` attribute exposing ``forward``.
    """
    model_cls = model.__class__
    if getattr(model_cls, "_forward_patched", False):
        return

    thinker_ok = hasattr(model, "thinker") and hasattr(model.thinker, "forward")
    if not thinker_ok:
        raise RuntimeError(
            "Cannot patch forward: model has no `.thinker.forward`. "
            "Your qwen3_asr model may be incompatible."
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        input_features=None,
        feature_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        # Named inputs are passed through explicitly; any extra keyword
        # arguments are forwarded unchanged.
        call_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            labels=labels,
        )
        call_kwargs.update(kwargs)
        return self.thinker.forward(**call_kwargs)

    setattr(model_cls, "forward", forward)
    setattr(model_cls, "_forward_patched", True)


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-step ``checkpoint-<N>`` subdirectory.

    Returns None when ``output_dir`` is empty, is not a directory, or contains
    no matching subdirectory. Plain files named like checkpoints are ignored.
    """
    if not output_dir or not os.path.isdir(output_dir):
        return None

    candidates = []
    for entry in os.listdir(output_dir):
        match = _CKPT_RE.match(entry)
        if match is None:
            continue
        full_path = os.path.join(output_dir, entry)
        if os.path.isdir(full_path):
            candidates.append((int(match.group(1)), full_path))

    if not candidates:
        return None
    # max() keeps the first entry on equal steps, like a strict ">" scan.
    return max(candidates, key=lambda item: item[0])[1]


def get_dist_info() -> Tuple[int, int]:
    """Return ``(rank, world_size)``, or ``(0, 1)`` outside an initialised process group."""
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0, 1
    return dist.get_rank(), dist.get_world_size()


def _as_str(x: Any) -> str:
    if x is None:
        return ""
    return str(x)


def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13


# ---------------------------------------------------------------------------
# IndexedTarReader: O(1) seek+read using pre-built .index.json files
# Preferred over TarAudioLRUCache (~400x faster per file read); TarAudioLRUCache remains the fallback for tars without an index
# ---------------------------------------------------------------------------

class IndexedTarReader:
    """Direct byte-offset tar reader using pre-built index files.

    ~400x faster than tarfile.extractfile() because:
    - No tarfile.getmembers() scan (1.4s -> 0ms)
    - Plain file handle seek+read instead of tarfile extraction
    - LRU cache of open file handles

    Each tar at ``<path>`` must have a sibling ``<path>.index.json`` (UTF-8 JSON)
    mapping member names to ``{"offset": int, "size": int}``; ``size`` bytes read
    at ``offset`` are returned as the member's contents. Paths are canonicalised
    with ``os.path.realpath`` so symlinked shards share one cache entry. Open file
    handles and parsed indices are each kept in a bounded LRU cache.
    """

    def __init__(self, max_open_files: int = 64, max_cached_indices: Optional[int] = None):
        """
        Args:
            max_open_files: Maximum number of tar file handles kept open (min 1).
            max_cached_indices: Maximum number of parsed indices kept in memory
                (min 1); falls back to ``max_open_files`` when None or 0.
        """
        self.max_open_files = int(max(1, max_open_files))
        # Loading one .index.json per tar is fast, but each index is ~1-2 MB in this
        # dataset. Keep it LRU-bounded or workers will OOM as they touch new shards.
        self.max_cached_indices = int(max(1, max_cached_indices or max_open_files))
        self._file_cache: OrderedDict = OrderedDict()  # real tar_path -> file handle
        self._index_cache: OrderedDict = OrderedDict()  # real tar_path -> {member_name: {offset, size}}

    def _load_index(self, tar_path: str) -> dict:
        """Return the parsed index for ``tar_path``, loading and caching it on a miss."""
        real_path = os.path.realpath(tar_path)
        if real_path in self._index_cache:
            self._index_cache.move_to_end(real_path)
            return self._index_cache[real_path]

        index_path = real_path + ".index.json"
        # JSON is UTF-8 by spec; member names may be non-ASCII, so do not depend
        # on the process locale's default encoding.
        with open(index_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        self._index_cache[real_path] = index
        if len(self._index_cache) > self.max_cached_indices:
            self._index_cache.popitem(last=False)
        return index

    def _get_file(self, tar_path: str):
        """Return a cached binary handle for ``tar_path``, evicting the LRU handle if full."""
        real_path = os.path.realpath(tar_path)
        if real_path in self._file_cache:
            self._file_cache.move_to_end(real_path)
            return self._file_cache[real_path]

        fh = open(real_path, "rb")
        self._file_cache[real_path] = fh
        if len(self._file_cache) > self.max_open_files:
            _, old_fh = self._file_cache.popitem(last=False)
            old_fh.close()
        return fh

    def read_member(self, tar_path: str, member_name: str) -> bytes:
        """Return the raw bytes of ``member_name`` inside ``tar_path``.

        The name is looked up as given, then with leading ``.``/``/`` characters
        stripped, then with a ``./`` prefix added.

        Raises:
            FileNotFoundError: if the index file is missing or lacks the member.
            OSError: if the tar ends before the indexed byte range (truncated shard).
        """
        index = self._load_index(tar_path)

        entry = index.get(member_name)
        if entry is None:
            # NOTE: lstrip removes any run of leading '.'/'/' characters, not just
            # a "./" prefix; kept for parity with TarAudioLRUCache.
            bare = member_name.lstrip("./")
            entry = index.get(bare) or index.get("./" + member_name)
        if entry is None:
            raise FileNotFoundError(f"Member {member_name} not found in index for {tar_path}")

        size = entry["size"]
        fh = self._get_file(tar_path)
        fh.seek(entry["offset"])
        data = fh.read(size)
        # A short read means the shard is truncated relative to its index; fail
        # loudly instead of handing partial audio to the decoder.
        if len(data) != size:
            raise OSError(
                f"Short read for {member_name} in {tar_path}: "
                f"expected {size} bytes, got {len(data)}"
            )
        return data

    def close(self):
        """Close all cached file handles and drop all cached indices."""
        for fh in self._file_cache.values():
            fh.close()
        self._file_cache.clear()
        self._index_cache.clear()


# Keep legacy TarAudioLRUCache as fallback when indices don't exist
class TarAudioLRUCache:
    """Fallback tar reader built on :mod:`tarfile` with an LRU of open archives.

    Used for tars without an ``.index.json`` sidecar. Archives are keyed by
    their real path, and at most ``max_open_tars`` stay open at once. The
    least recently used archive is closed when that limit is exceeded.
    """

    def __init__(self, max_open_tars: int = 16):
        """
        Args:
            max_open_tars: Maximum number of archives kept open (min 1).
        """
        self.max_open_tars = int(max(1, max_open_tars))
        self._cache: "OrderedDict[str, Any]" = OrderedDict()

    def close(self):
        """Close every cached archive (close errors are ignored) and empty the cache."""
        for tf in self._cache.values():
            try:
                tf.close()
            except Exception:
                pass
        self._cache.clear()

    def _candidate_member_names(self, name: str) -> List[str]:
        """Return lookup variants of ``name``: as given, with "./" added, and stripped.

        NOTE: lstrip removes any run of leading '.'/'/' characters, not only a
        "./" prefix.
        """
        base = name.lstrip("./")
        out = [name]
        if not name.startswith("./"):
            out.append(f"./{name}")
        if base != name:
            out.append(base)
        return out

    def _get_tar(self, tar_path: str):
        """Return an open TarFile for ``tar_path``, opening and caching it on a miss."""
        import tarfile as _tarfile
        real_path = os.path.realpath(tar_path)
        tf = self._cache.get(real_path)
        if tf is not None:
            self._cache.move_to_end(real_path)
            return tf
        tf = _tarfile.open(real_path, mode="r:")
        self._cache[real_path] = tf
        if len(self._cache) > self.max_open_tars:
            _, old_tf = self._cache.popitem(last=False)
            try:
                old_tf.close()
            except Exception:
                pass
        return tf

    def read_member(self, tar_path: str, member_name: str) -> bytes:
        """Return the bytes of the first candidate name that is a regular file in the tar.

        Raises:
            FileNotFoundError: if no candidate name resolves to a regular file.
        """
        tf = self._get_tar(tar_path)
        for cand in self._candidate_member_names(member_name):
            try:
                fobj = tf.extractfile(cand)
            except KeyError:
                continue
            if fobj is None:
                # Member exists but is not a regular file (e.g. a directory).
                continue
            # Closing the member reader does not close the underlying archive.
            with fobj:
                return fobj.read()
        raise FileNotFoundError(f"Member not found: {member_name} in {tar_path}")


class HybridTarReader:
    """Dispatch each read to IndexedTarReader or TarAudioLRUCache per tar file.

    A tar is served by the indexed reader when ``<realpath>.index.json`` exists.
    Otherwise the tarfile-based reader is used, with a quarter of the handle
    budget (min 1). Whether an index exists is checked once per real path and
    cached for the reader's lifetime, so an index created later is not noticed.
    """

    def __init__(self, max_open_files: int = 64):
        self._indexed = IndexedTarReader(max_open_files=max_open_files)
        legacy_capacity = max(1, max_open_files // 4)
        self._legacy = TarAudioLRUCache(max_open_tars=legacy_capacity)
        self._has_index: Dict[str, bool] = {}

    def _check_index(self, tar_path: str) -> bool:
        key = os.path.realpath(tar_path)
        known = self._has_index.get(key)
        if known is None:
            known = os.path.exists(key + ".index.json")
            self._has_index[key] = known
        return known

    def read_member(self, tar_path: str, member_name: str) -> bytes:
        reader = self._indexed if self._check_index(tar_path) else self._legacy
        return reader.read_member(tar_path, member_name)

    def close(self):
        for reader in (self._indexed, self._legacy):
            reader.close()


def _make_tar_reader(max_open: int = 64, use_indexed: bool = True) -> Any:
    """Build a tar reader: HybridTarReader by default, plain TarAudioLRUCache otherwise.

    ``max_open`` bounds the number of simultaneously open tar files.
    """
    if not use_indexed:
        return TarAudioLRUCache(max_open_tars=max_open)
    return HybridTarReader(max_open_files=max_open)


def _decode_audio_bytes(raw: bytes, target_sr: int = 16000) -> Optional[np.ndarray]:
    """Decode raw audio bytes to float32 numpy array at target_sr."""
    try:
        wav, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)
        wav = np.asarray(wav, dtype=np.float32)
        if wav.ndim == 2:
            wav = wav.mean(axis=1).astype(np.float32)
        if sr != target_sr:
            wav_t = torch.from_numpy(wav).unsqueeze(0)
            wav_t = torchaudio.functional.resample(wav_t, sr, target_sr)
            wav = wav_t.squeeze(0).numpy()
        return wav
    except Exception:
        return None


# ---------------------------------------------------------------------------
# BucketedIterableDataset: reads per-bucket parquets, yields decoded waveforms
# Worker-side decode: each worker decodes audio, collator only pads+tokenizes
# ---------------------------------------------------------------------------

class BucketedIterableDataset(IterableDataset):
    """Reads from per-bucket parquet files, decodes audio in workers.

    Yields PRE-FORMED BATCHES (lists of sample dicts) with per-bucket batch sizes.
    The DataLoader should use batch_size=1 and the collator unwraps the inner list.

    Key features:
    - Same-duration batching (zero cross-bucket padding)
    - Dynamic per-bucket batch sizes (short audio = big BS, long audio = small BS)
    - Worker-side audio decoding (parallelized across num_workers)
    - IndexedTarReader for O(1) file access
    - Weighted bucket sampling proportional to audio hours

    Each sample dict carries: sample_id, waveform (decoded mono numpy array at
    ``target_sr``), transcript (markup removed), language (script-corrected),
    prompt, duration_s, bucket_id, tar_path, tar_member_name.
    """

    # Default per-bucket batch sizes (calibrated for H200 143GB, no GC)
    # Memory: ~0.77 GB per (sample × audio_second). Target ≤125GB worst-case peak.
    # OOM at b_10_15 BS=16 (152GB needed). Calibrated with 18GB safety margin.
    # NOTE(review): by this formula b_0_3 and b_5_7 (~129GB) exceed the stated
    # 125GB target — confirm they were validated empirically.
    # Keys must match bucket_id values in bucket_config; the constructor's
    # ``bucket_bs`` argument overrides individual entries.
    DEFAULT_BUCKET_BS = {
        "b_0_3":  56,   # 0-3s: worst 56×3×0.77=129GB
        "b_3_5":  32,   # 3-5s: worst 32×5×0.77=123GB
        "b_5_7":  24,   # 5-7s: worst 24×7×0.77=129GB
        "b_7_10": 16,   # 7-10s: worst 16×10×0.77=123GB (34% of hours)
        "b_10_15": 10,  # 10-15s: worst 10×15×0.77=116GB (30% of hours)
        "b_15_20": 8,   # 15-20s: worst 8×20×0.77=123GB
        "b_20_30": 4,   # 20-30s: worst 4×30×0.77=92GB
    }

    def __init__(
        self,
        bucket_dir: str,
        bucket_config: dict,
        prompt: str = "",
        seed: int = 42,
        manifest_batch_rows: int = 65536,
        shuffle_row_groups: bool = True,
        shuffle_within_batch: bool = True,
        skip_empty_transcript: bool = True,
        max_open_tars: int = 64,
        use_indexed_tar: bool = True,
        target_sr: int = 16000,
        bucket_bs: Optional[Dict[str, int]] = None,
        default_batch_size: int = 16,
    ):
        """Register the bucket parquets listed in ``bucket_config``.

        Args:
            bucket_dir: Directory holding ``<bucket_id>.parquet`` files.
            bucket_config: Dict with a ``"buckets"`` list. Each entry needs
                ``bucket_id`` and ``samples``, and may carry ``hours`` (else
                ``pct_hours``, else 1.0) as its sampling weight. Buckets whose
                parquet file is missing are silently skipped.
            prompt: System prompt attached to every sample.
            seed: Base seed for bucket choice and row shuffles (mixed with epoch).
            manifest_batch_rows: Rows per pyarrow record batch when reading.
            shuffle_row_groups: Shuffle row-group order (identically on every shard).
            shuffle_within_batch: Shuffle rows inside each record batch
                (identically on every shard).
            skip_empty_transcript: Skip rows whose raw transcript is empty.
            max_open_tars: Cache size for the per-iteration tar reader.
            use_indexed_tar: Prefer ``.index.json`` sidecars over tarfile scans.
            target_sr: Sample rate that audio is resampled to.
            bucket_bs: Per-bucket batch-size overrides merged over
                ``DEFAULT_BUCKET_BS``; values may be ints or ``{"batch_size": int}``.
            default_batch_size: Batch size for buckets absent from both maps.

        Raises:
            ValueError: if none of the configured bucket parquets exist.
        """
        self.bucket_dir = bucket_dir
        self.prompt = prompt
        self.seed = int(seed)
        self.manifest_batch_rows = int(manifest_batch_rows)
        self.shuffle_row_groups = bool(shuffle_row_groups)
        self.shuffle_within_batch = bool(shuffle_within_batch)
        self.skip_empty_transcript = bool(skip_empty_transcript)
        self.max_open_tars = max_open_tars
        self.use_indexed_tar = use_indexed_tar
        self.target_sr = target_sr
        self.default_batch_size = default_batch_size
        self.epoch = 0

        # Per-bucket batch sizes
        _bucket_bs = dict(self.DEFAULT_BUCKET_BS)
        if bucket_bs:
            _bucket_bs.update(bucket_bs)

        # Parse bucket config
        self.buckets = []
        total_samples = 0
        for b in bucket_config["buckets"]:
            path = os.path.join(bucket_dir, f"{b['bucket_id']}.parquet")
            if not os.path.exists(path):
                continue
            bid = b["bucket_id"]
            self.buckets.append({
                "id": bid,
                "path": path,
                # Sampling weight: "hours" if present, else "pct_hours", else 1.0.
                "weight": b.get("hours", b.get("pct_hours", 1.0)),
                "samples": b["samples"],
                # Support both {bid: int} and {bid: {batch_size: int}} formats
                "batch_size": (lambda v: v["batch_size"] if isinstance(v, dict) else v)(
                    _bucket_bs.get(bid, default_batch_size)
                ),
            })
            total_samples += b["samples"]

        if not self.buckets:
            raise ValueError(
                f"No bucket parquet files found in {bucket_dir}. "
                "Expected files like b_0_3.parquet alongside bucket_config.json."
            )

        # Sum of configured sample counts over present buckets; rows later
        # filtered out during iteration are still counted here.
        self._estimated_num_rows = total_samples

        # Normalize weights
        total_weight = sum(b["weight"] for b in self.buckets)
        for b in self.buckets:
            b["norm_weight"] = b["weight"] / total_weight if total_weight > 0 else 1.0 / len(self.buckets)

    @property
    def estimated_num_rows(self) -> int:
        """Total configured sample count across the buckets whose parquet exists."""
        return self._estimated_num_rows

    def set_epoch(self, epoch: int):
        """Set the epoch mixed into all shuffle / bucket-choice seeds."""
        self.epoch = int(epoch)

    def __iter__(self):
        """Yields pre-formed batches with dynamic per-bucket batch sizes.

        Each yielded item is a list of sample dicts from a single bucket.
        DataLoader should use batch_size=1 and collator unwraps the inner list.
        """
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1
        rank, world_size = get_dist_info()
        # One shard per (rank, DataLoader worker) pair.
        shard_id = rank * num_workers + worker_id
        num_shards = world_size * num_workers

        # Shared seed across shards keeps bucket selection order deterministic.
        # This avoids cross-rank step drift when per-bucket BS differs.
        # NOTE: every worker on every rank draws the same sequence; the picked
        # buckets diverge only once buckets exhaust at different times per shard.
        rng = random.Random(self.seed + self.epoch * 1000003)

        # Each iteration (i.e. each worker process) owns its own tar reader.
        tar_reader = _make_tar_reader(self.max_open_tars, self.use_indexed_tar)

        columns = ["sample_id", "language", "transcript", "tar_path",
                    "tar_member_name", "duration_s", "bucket_id"]

        try:
            bucket_iters = []
            for b in self.buckets:
                bucket_iters.append({
                    "info": b,
                    "iter": self._iter_bucket(
                        b["path"], columns, shard_id, num_shards, tar_reader
                    ),
                    "exhausted": False,
                    "buffer": [],
                })

            weights = [b["info"]["norm_weight"] for b in bucket_iters]

            while True:
                active = [(i, bi) for i, bi in enumerate(bucket_iters) if not bi["exhausted"]]
                if not active:
                    break

                active_weights = [weights[i] for i, _ in active]
                total_w = sum(active_weights)
                # Only zero-weight buckets remain; they are never sampled.
                if total_w <= 0:
                    break

                chosen_idx = rng.choices([i for i, _ in active], weights=active_weights, k=1)[0]
                bi = bucket_iters[chosen_idx]
                target_bs = int(max(1, bi["info"]["batch_size"]))

                # Fill buffer for this bucket to target dynamic BS.
                while len(bi["buffer"]) < target_bs:
                    sample = next(bi["iter"], None)
                    if sample is None:
                        bi["exhausted"] = True
                        break
                    bi["buffer"].append(sample)

                # A bucket that just ran out yields its final, possibly short, batch here.
                if bi["buffer"]:
                    batch = bi["buffer"][:target_bs]
                    bi["buffer"] = bi["buffer"][target_bs:]
                    yield batch

            # Flush partial tail batches.
            # NOTE(review): buffers never exceed target_bs and are drained on every
            # pick above, so this appears to be a defensive no-op.
            for bi in bucket_iters:
                if bi["buffer"]:
                    yield bi["buffer"]
                    bi["buffer"] = []

        finally:
            tar_reader.close()

    def _iter_bucket(self, parquet_path, columns, shard_id, num_shards, tar_reader):
        """Iterate a single bucket's parquet, yielding decoded samples.

        Sharding is done at the SAMPLE level (not row-group level) because
        bucket parquets may have fewer row groups than total shards (GPUs × workers).
        Every worker reads all row groups but only processes every num_shards-th sample.

        IMPORTANT: Row group order and within-batch shuffle must be IDENTICAL across
        all shards so that global_sample_idx maps to the same physical sample everywhere.
        We use a SHARED seed (not shard-specific) for these shuffles.

        Rows are dropped (without raising) when the transcript is empty (raw or
        after markup removal), the tar location is missing, or the audio cannot
        be read or decoded.
        """
        pf = pq.ParquetFile(parquet_path)
        num_rg = pf.num_row_groups

        row_groups = list(range(num_rg))
        # Shared RNG across all shards for deterministic row group order
        shared_rng = random.Random(self.seed + self.epoch * 1000003)
        if self.shuffle_row_groups:
            shared_rng.shuffle(row_groups)

        global_sample_idx = 0
        for rg_idx in row_groups:
            batches = pf.iter_batches(
                batch_size=self.manifest_batch_rows,
                row_groups=[rg_idx],
                columns=columns,
                use_threads=True,
            )
            for batch in batches:
                # Every shard converts the full record batch to Python lists even
                # though it keeps only ~1/num_shards of the rows; audio is decoded
                # only for owned rows.
                data = {name: batch.column(i).to_pylist()
                        for i, name in enumerate(batch.schema.names)}
                n = batch.num_rows
                indices = list(range(n))
                if self.shuffle_within_batch:
                    shared_rng.shuffle(indices)  # shared across shards for consistent global_sample_idx

                for i in indices:
                    # Sample-level sharding: each worker processes every num_shards-th sample
                    if global_sample_idx % num_shards != shard_id:
                        global_sample_idx += 1
                        continue
                    global_sample_idx += 1

                    transcript = _as_str(data["transcript"][i]).strip()
                    # NOTE(review): an empty transcript is also dropped by the
                    # post-cleaning check below, so skip_empty_transcript=False does
                    # not let empty transcripts through.
                    if self.skip_empty_transcript and not transcript:
                        continue

                    tar_path = _as_str(data["tar_path"][i])
                    tar_member_name = _as_str(data["tar_member_name"][i])
                    if not tar_path or not tar_member_name:
                        continue

                    # Clean transcript
                    transcript = clean_transcript(transcript)
                    if not transcript:
                        continue

                    # Detect script language
                    declared_lang = _as_str(data["language"][i])
                    effective_lang = detect_script_language(transcript, declared_lang)

                    # WORKER-SIDE AUDIO DECODE
                    try:
                        raw_bytes = tar_reader.read_member(tar_path, tar_member_name)
                        wav = _decode_audio_bytes(raw_bytes, self.target_sr)
                        if wav is None or len(wav) == 0:
                            continue
                    except Exception:
                        continue

                    duration_s = data["duration_s"][i]
                    # NOTE(review): "bucket_id" is among the requested columns, so
                    # iter_batches presumably fails before this point if it is absent;
                    # the fallback here is defensive.
                    bucket_id = _as_str(data.get("bucket_id", [""])[i]) if "bucket_id" in data else ""

                    yield {
                        "sample_id": _as_str(data["sample_id"][i]),
                        "waveform": wav,
                        "transcript": transcript,
                        "language": effective_lang,
                        "prompt": self.prompt,
                        "duration_s": float(duration_s) if duration_s is not None else len(wav) / self.target_sr,
                        "bucket_id": bucket_id,
                        "tar_path": tar_path,
                        "tar_member_name": tar_member_name,
                    }


# Legacy dataset for non-bucketed mode (backward compatible)
class ParquetTarIterableDataset(IterableDataset):
    """Stream ASR samples from a single manifest parquet (non-bucketed mode).

    Row groups are sharded across (rank, DataLoader worker) pairs. The row-group
    order is shuffled with a seed shared by every shard, and shard ``k`` then
    takes every ``num_shards``-th row group starting at position ``k``. Rows
    inside each manifest batch are shuffled with a shard-specific seed.

    Rows are dropped when they belong to another split, have an empty
    transcript (raw or after markup removal), fall outside
    ``[min_duration_s, max_duration_s]``, lack a tar location, or (in
    worker-decode mode) have unreadable audio. The language is corrected with
    ``detect_script_language``.

    With ``worker_decode=True`` each worker reads and decodes the audio and
    yields a ``"waveform"``. Otherwise it yields the tar location for the
    collator to decode. ``max_samples`` (when > 0) caps the number of samples
    yielded per shard.

    NOTE(review): with row-group sharding, some shards receive no data when the
    parquet has fewer row groups than ``world_size * num_workers``.
    """

    def __init__(
        self,
        parquet_path: str,
        split: str = "",
        prompt: str = "",
        seed: int = 42,
        manifest_batch_rows: int = 65536,
        shuffle_row_groups: bool = True,
        shuffle_within_batch: bool = True,
        min_duration_s: float = 0.0,
        max_duration_s: float = 1200.0,
        skip_empty_transcript: bool = True,
        max_samples: int = 0,
        max_open_tars: int = 64,
        use_indexed_tar: bool = True,
        target_sr: int = 16000,
        worker_decode: bool = True,
    ):
        self.parquet_path = parquet_path
        self.split = split.strip()
        self.prompt = prompt
        self.seed = int(seed)
        self.manifest_batch_rows = int(manifest_batch_rows)
        self.shuffle_row_groups = bool(shuffle_row_groups)
        self.shuffle_within_batch = bool(shuffle_within_batch)
        self.min_duration_s = float(min_duration_s)
        self.max_duration_s = float(max_duration_s)
        self.skip_empty_transcript = bool(skip_empty_transcript)
        self.max_samples = int(max_samples)
        self.max_open_tars = max_open_tars
        self.use_indexed_tar = use_indexed_tar
        self.target_sr = target_sr
        self.worker_decode = worker_decode
        self.epoch = 0

        # Row count from parquet metadata (before any filtering).
        pf = pq.ParquetFile(self.parquet_path)
        self._estimated_num_rows = int(pf.metadata.num_rows)

    @property
    def estimated_num_rows(self) -> int:
        """Total manifest rows from parquet metadata; filtered rows are included."""
        return self._estimated_num_rows

    def set_epoch(self, epoch: int):
        """Set the epoch mixed into the shuffle seeds."""
        self.epoch = int(epoch)

    def _iter_row_groups_for_shard(self, num_row_groups: int) -> Iterable[int]:
        """Yield the row-group indices owned by the current (rank, worker) shard."""
        row_groups = list(range(num_row_groups))

        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1
        rank, world_size = get_dist_info()
        shard_id = rank * num_workers + worker_id
        num_shards = world_size * num_workers

        # SHARED seed across all shards so shuffle produces identical ordering
        rng = random.Random(self.seed + self.epoch * 1000003)
        if self.shuffle_row_groups:
            rng.shuffle(row_groups)

        for idx, rg in enumerate(row_groups):
            if idx % num_shards == shard_id:
                yield rg

    def __iter__(self):
        pf = pq.ParquetFile(self.parquet_path)
        columns = [
            "sample_id", "language", "transcript",
            "tar_path", "tar_member_name", "duration_s", "split",
        ]

        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1
        rank, world_size = get_dist_info()
        # Shard-specific seed: row groups are already disjoint per shard, so the
        # within-batch shuffle does not need to agree across shards.
        shard_seed = self.seed + self.epoch * 1000003 + rank * 10007 + worker_id * 1009
        rng = random.Random(shard_seed)

        # Each worker gets its own tar reader
        tar_reader = _make_tar_reader(self.max_open_tars, self.use_indexed_tar) if self.worker_decode else None

        try:
            yielded = 0
            for rg_idx in self._iter_row_groups_for_shard(pf.num_row_groups):
                batches = pf.iter_batches(
                    batch_size=self.manifest_batch_rows,
                    row_groups=[rg_idx],
                    columns=columns,
                    use_threads=True,
                )
                for batch in batches:
                    data = {name: batch.column(i).to_pylist() for i, name in enumerate(batch.schema.names)}
                    n = batch.num_rows
                    # Resolve the split column once per batch: building the
                    # ``[""] * n`` fallback per row would cost O(n) for every row.
                    split_col = data["split"] if "split" in data else [""] * n
                    indices = list(range(n))
                    if self.shuffle_within_batch:
                        rng.shuffle(indices)

                    for i in indices:
                        row_split = _as_str(split_col[i])
                        # Rows with no split value are kept for any requested split.
                        if self.split and row_split and row_split != self.split:
                            continue

                        transcript = _as_str(data["transcript"][i]).strip()
                        if self.skip_empty_transcript and not transcript:
                            continue

                        duration_s = data["duration_s"][i]
                        if duration_s is not None:
                            duration_s = float(duration_s)
                            if duration_s < self.min_duration_s or duration_s > self.max_duration_s:
                                continue

                        tar_path = _as_str(data["tar_path"][i])
                        tar_member_name = _as_str(data["tar_member_name"][i])
                        if not tar_path or not tar_member_name:
                            continue

                        # Clean transcript
                        transcript = clean_transcript(transcript)
                        if not transcript:
                            continue

                        # Detect script language
                        declared_lang = _as_str(data["language"][i])
                        effective_lang = detect_script_language(transcript, declared_lang)

                        if self.worker_decode and tar_reader is not None:
                            # WORKER-SIDE DECODE
                            try:
                                raw_bytes = tar_reader.read_member(tar_path, tar_member_name)
                                wav = _decode_audio_bytes(raw_bytes, self.target_sr)
                                if wav is None or len(wav) == 0:
                                    continue
                            except Exception:
                                continue

                            # Same identifying keys as the legacy branch and
                            # BucketedIterableDataset, for debugging / breadcrumbs.
                            yield {
                                "sample_id": _as_str(data["sample_id"][i]),
                                "waveform": wav,
                                "transcript": transcript,
                                "language": effective_lang,
                                "prompt": self.prompt,
                                "duration_s": float(duration_s) if duration_s is not None else len(wav) / self.target_sr,
                                "tar_path": tar_path,
                                "tar_member_name": tar_member_name,
                            }
                        else:
                            # Legacy mode: collator decodes
                            yield {
                                "sample_id": _as_str(data["sample_id"][i]),
                                "language": effective_lang,
                                "transcript": transcript,
                                "prompt": self.prompt,
                                "tar_path": tar_path,
                                "tar_member_name": tar_member_name,
                                "duration_s": float(duration_s) if duration_s is not None else None,
                            }

                        yielded += 1
                        if self.max_samples > 0 and yielded >= self.max_samples:
                            return
        finally:
            if tar_reader is not None:
                tar_reader.close()


def build_prefix_messages(prompt: str, audio_array):
    """Build the chat prefix: a system turn with ``prompt`` and a user turn holding the audio.

    A falsy ``prompt`` (None or "") becomes an empty system message.
    """
    system_turn = {"role": "system", "content": prompt or ""}
    audio_part = {"type": "audio", "audio": audio_array}
    user_turn = {"role": "user", "content": [audio_part]}
    return [system_turn, user_turn]


def format_target_text(
    transcript: str,
    language_code: str,
    language_tag_mode: str,
    transcript_has_prefix: bool,
) -> str:
    """Build the training target ``"language <tag><asr_text><transcript>"``.

    Args:
        transcript: Transcript text; surrounding whitespace is stripped.
        language_code: Language code, compared case- and whitespace-insensitively;
            may be empty or None.
        language_tag_mode: ``"none"`` gives the tag ``None``. ``"code"`` gives the
            lowercased code (``None`` if empty). Any other mode (``"name"``,
            ``"auto"``, ...) uses the name from ``LANG_CODE_TO_NAME`` (``None`` if
            unknown).
        transcript_has_prefix: If True, the stripped transcript already contains
            the prefix and is returned unchanged.
    """
    cleaned = transcript.strip()
    if transcript_has_prefix:
        return cleaned

    normalized_code = (language_code or "").strip().lower()
    if language_tag_mode == "none":
        tag = "None"
    elif language_tag_mode == "code":
        tag = normalized_code or "None"
    else:
        # "name", "auto" and unrecognised modes all use the readable name.
        tag = LANG_CODE_TO_NAME.get(normalized_code, "None")
    return f"language {tag}<asr_text>{cleaned}"


@dataclass
class DataCollatorForQwen3ASRPhase2:
    """Collate dataset samples into a padded Qwen3-ASR training batch.

    Each sample is handled in one of two ways:

    * worker-decoded: the sample carries a ``"waveform"`` (decoded by the
      dataset worker) plus ``"transcript"`` / ``"language"``;
    * legacy: the sample carries ``"tar_path"`` / ``"tar_member_name"`` and the
      collator reads, down-mixes and resamples the audio itself.

    Samples that fail decoding or filtering are dropped. If nothing survives,
    or the padded sequence is longer than ``max_batch_seq_len``, a dummy batch
    with ``loss_scale == 0`` and a non-empty ``_batch_skip_reason`` is returned
    instead. Every batch carries ``_batch_*`` profiling/debug entries.
    """
    processor: Any
    sampling_rate: int = 16000
    max_open_tars: int = 16
    language_tag_mode: str = "auto"
    transcript_has_prefix: bool = False
    skip_decode_errors: bool = True
    worker_decode: bool = True
    max_batch_seq_len: int = 500
    # Samples longer than this many seconds are dropped; <= 0 disables the check.
    max_audio_s: float = 30.0

    def __post_init__(self):
        # The tar cache is only needed for collator-side (legacy) decoding. It is
        # created eagerly in legacy mode and lazily otherwise, because a
        # worker_decode collator still falls back to the legacy path for any
        # sample that arrives without a "waveform".
        self._tar_cache = None
        if not self.worker_decode:
            self._tar_cache = TarAudioLRUCache(max_open_tars=self.max_open_tars)
        self._prefix_cache: Dict[str, str] = {}

    def _get_tar_cache(self):
        """Return the tar reader cache, creating it on first use."""
        cache = getattr(self, "_tar_cache", None)
        if cache is None:
            cache = TarAudioLRUCache(max_open_tars=self.max_open_tars)
            self._tar_cache = cache
        return cache

    @staticmethod
    def _warn(message: str) -> None:
        """Emit a Python warning attributed to the calling collator method."""
        import warnings
        warnings.warn(message, stacklevel=3)

    @staticmethod
    def _attach_batch_stats(
        inputs: Dict[str, Any],
        batch_debug: Dict[str, Any],
        audio_secs: float,
        skip_reason: str,
    ) -> None:
        """Embed per-batch profiling stats and debug breadcrumbs into ``inputs``.

        collate_fn runs in DataLoader worker processes, so module-level dicts do
        not propagate back to the trainer; the stats travel inside the batch.
        """
        inputs["_batch_size"] = torch.tensor(
            int(batch_debug.get("sample_count", 0)), dtype=torch.long
        )
        inputs["_batch_audio_secs"] = torch.tensor(float(audio_secs), dtype=torch.float32)
        for out_key, debug_key in (
            ("_batch_seq_len", "seq_len"),
            ("_batch_attn_tokens_sum", "attn_tokens_sum"),
            ("_batch_audio_tokens_max", "audio_tokens_max"),
            ("_batch_label_tokens_sum", "label_tokens_sum"),
        ):
            inputs[out_key] = torch.tensor(int(batch_debug.get(debug_key, 0)), dtype=torch.long)
        inputs["_batch_skip_reason"] = skip_reason
        inputs["_batch_debug_json"] = json.dumps(
            batch_debug, ensure_ascii=True, separators=(",", ":")
        )

    def _make_zero_loss_batch(
        self,
        batch_debug: Optional[Dict[str, Any]] = None,
        skip_reason: str = "",
    ) -> Dict[str, Any]:
        """Return a dummy batch that contributes zero loss but preserves breadcrumbs.

        The batch holds one second of silence and a fixed "dummy" target. It is
        marked with ``loss_scale == 0`` and ``skip_reason`` so the trainer can
        skip it; ``batch_debug`` stats describe the batch that was replaced.
        """
        debug = dict(batch_debug or {})
        eos = self.processor.tokenizer.eos_token or ""
        dummy_text = self._build_prefix_text("") + "language None<asr_text>dummy" + eos
        # One second of silence at the collator's configured sampling rate.
        dummy_wav = np.zeros(int(self.sampling_rate), dtype=np.float32)
        full_inputs = self.processor(
            text=[dummy_text],
            audio=[dummy_wav],
            return_tensors="pt",
            padding=True,
            truncation=False,
        )
        full_inputs["labels"] = full_inputs["input_ids"].clone()
        full_inputs["loss_scale"] = torch.tensor(0.0, dtype=torch.float32)
        self._attach_batch_stats(
            full_inputs,
            debug,
            audio_secs=float(debug.get("audio_secs", 0.0)),
            skip_reason=skip_reason,
        )
        return full_inputs

    def _build_prefix_text(self, prompt: str) -> str:
        """Return (and memoize) the chat-templated prefix for ``prompt``."""
        if prompt in self._prefix_cache:
            return self._prefix_cache[prompt]
        prefix_msgs = build_prefix_messages(prompt, None)
        prefix_text = self.processor.apply_chat_template(
            [prefix_msgs], add_generation_prompt=True, tokenize=False
        )[0]
        self._prefix_cache[prompt] = prefix_text
        return prefix_text

    def _decode_audio_legacy(self, tar_path: str, member_name: str) -> np.ndarray:
        """Legacy: read one tar member, down-mix to mono, resample to ``sampling_rate``."""
        raw = self._get_tar_cache().read_member(tar_path, member_name)
        wav, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)
        wav = np.asarray(wav, dtype=np.float32)
        if wav.ndim == 2:
            wav = wav.mean(axis=1).astype(np.float32)
        if sr != self.sampling_rate:
            wav_t = torch.from_numpy(wav).unsqueeze(0)
            wav_t = torchaudio.functional.resample(wav_t, sr, self.sampling_rate)
            wav = wav_t.squeeze(0).numpy()
        return wav

    def _collect_sample(self, f: Any) -> Optional[Tuple[Any, str, str, Dict[str, Any]]]:
        """Decode, validate and format one sample.

        Returns ``(waveform, prefix_text, target_text, debug_entry)``, or ``None``
        when the sample must be dropped (empty audio/transcript, decode error
        with ``skip_decode_errors``, NaN/Inf audio, or audio over ``max_audio_s``).
        """
        if self.worker_decode and "waveform" in f:
            # Worker already decoded the audio.
            wav = f["waveform"]
            if wav is None or len(wav) == 0:
                return None
            transcript = (f.get("transcript") or "").strip()
            language = f.get("language", "")
        else:
            # Legacy: decode in the collator.
            try:
                wav = self._decode_audio_legacy(f["tar_path"], f["tar_member_name"])
            except Exception:
                if self.skip_decode_errors:
                    return None
                raise
            transcript = clean_transcript(_as_str(f.get("transcript", "")))
            if not transcript:
                return None
            language = detect_script_language(transcript, _as_str(f.get("language", "")))

        if not transcript:
            return None

        # Sanitize audio before it reaches the processor / GPU.
        wav_arr = np.asarray(wav)
        if wav_arr.size == 0:
            return None
        sid = f.get("sample_id", "?")
        if not np.isfinite(wav_arr).all():
            self._warn(f"Skipping sample {sid}: audio contains NaN/Inf")
            return None
        if self.max_audio_s > 0 and len(wav_arr) > self.max_audio_s * self.sampling_rate:
            self._warn(
                f"Skipping sample {sid}: audio too long ({len(wav_arr)/self.sampling_rate:.1f}s)"
            )
            return None

        target = format_target_text(
            transcript=transcript,
            language_code=language,
            language_tag_mode=self.language_tag_mode,
            transcript_has_prefix=self.transcript_has_prefix,
        )
        prefix_text = self._build_prefix_text(_as_str(f.get("prompt", "")))

        duration_s = f.get("duration_s")
        if duration_s is None:
            duration_s = len(wav_arr) / self.sampling_rate
        debug_entry = {
            "sample_id": _as_str(f.get("sample_id", "")),
            "tar_member_name": _as_str(f.get("tar_member_name", "")),
            "bucket_id": _as_str(f.get("bucket_id", "")),
            "duration_s": round(float(duration_s), 3),
            "transcript_len": len(transcript),
        }
        return wav, prefix_text, target, debug_entry

    def __call__(self, features: List[Any]) -> Dict[str, torch.Tensor]:
        """Build a training batch with prefix-masked labels from ``features``."""
        # Bucketed dataset emits pre-formed batches. DataLoader wraps that once
        # when batch_size=1, so unwrap here.
        if features and isinstance(features[0], list):
            features = features[0]

        audios: List[Any] = []
        prefix_texts: List[str] = []
        targets: List[str] = []
        sample_debug: List[Dict[str, Any]] = []
        for f in features:
            collected = self._collect_sample(f)
            if collected is None:
                continue
            wav, prefix_text, target, debug_entry = collected
            audios.append(wav)
            prefix_texts.append(prefix_text)
            targets.append(target)
            sample_debug.append(debug_entry)

        if not audios:
            self._warn("All samples in batch failed decoding/filtering — returning dummy batch")
            return self._make_zero_loss_batch(
                batch_debug={"sample_count": 0, "samples": sample_debug},
                skip_reason="all_samples_filtered",
            )

        eos = self.processor.tokenizer.eos_token or ""
        full_texts = [pfx + tgt + eos for pfx, tgt in zip(prefix_texts, targets)]
        full_inputs = self.processor(
            text=full_texts,
            audio=audios,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )

        # Tokenize the prefixes with the same audio-placeholder expansion the
        # processor applied to full_texts, so their lengths line up.
        audio_token_lens = _get_feat_extract_output_lengths(
            full_inputs["feature_attention_mask"].sum(dim=1)
        ).tolist()
        prefix_expanded = self.processor.replace_multimodal_special_tokens(
            prefix_texts,
            iter(audio_token_lens),
        )
        prefix_tok = self.processor.tokenizer(
            prefix_expanded,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )
        prefix_lens = prefix_tok["attention_mask"].sum(dim=1).tolist()

        # Supervise only the target: mask the first prefix_len non-padding
        # positions of every row (independent of padding side), then all pads.
        labels = full_inputs["input_ids"].clone()
        for row, prefix_len in enumerate(prefix_lens):
            valid_positions = torch.nonzero(full_inputs["attention_mask"][row], as_tuple=True)[0]
            labels[row, valid_positions[:prefix_len]] = -100
        pad_id = self.processor.tokenizer.pad_token_id
        if pad_id is not None:
            labels[labels == pad_id] = -100

        batch_audio_secs = sum(len(a) / self.sampling_rate for a in audios)
        batch_debug = {
            "bucket_id": ",".join(sorted({s.get("bucket_id", "") for s in sample_debug if s.get("bucket_id", "")})),
            "sample_count": len(audios),
            "audio_secs": round(batch_audio_secs, 3),
            "seq_len": int(full_inputs["input_ids"].shape[1]),
            "attn_tokens_sum": int(full_inputs["attention_mask"].sum().item()),
            "audio_tokens_max": int(max(audio_token_lens) if audio_token_lens else 0),
            "label_tokens_sum": int((labels != -100).sum().item()),
            "samples": sorted(
                sample_debug,
                key=lambda x: (x.get("transcript_len", 0), x.get("duration_s", 0.0)),
                reverse=True,
            ),
        }
        if self.max_batch_seq_len > 0 and batch_debug["seq_len"] > self.max_batch_seq_len:
            return self._make_zero_loss_batch(
                batch_debug=batch_debug,
                skip_reason=f"seq_len {batch_debug['seq_len']} > {self.max_batch_seq_len}",
            )

        full_inputs["labels"] = labels
        full_inputs["loss_scale"] = torch.tensor(1.0, dtype=torch.float32)
        self._attach_batch_stats(
            full_inputs, batch_debug, audio_secs=batch_audio_secs, skip_reason=""
        )
        return full_inputs


class CastFloatInputsTrainer(Trainer):
    """Trainer with collator-metadata handling, pre-forward batch skipping and
    eval-dataloader cleanup.

    * Pops the collator's ``_batch_*`` metadata into ``_BATCH_STATS`` so the
      model never receives it and callbacks/log lines can read it.
    * Casts floating-point inputs to the model dtype.
    * Skips batches, in lock-step across DDP ranks, that the collator marked
      as zero-loss dummies or that fail ``_sanitize_inputs``.
    """

    # (batch key emitted by the collator, _BATCH_STATS key, converter, default)
    _TENSOR_STAT_FIELDS = (
        ("_batch_size", "batch_size", int, 0),
        ("_batch_audio_secs", "audio_secs", float, 0.0),
        ("_batch_seq_len", "seq_len", int, 0),
        ("_batch_attn_tokens_sum", "attn_tokens_sum", int, 0),
        ("_batch_audio_tokens_max", "audio_tokens_max", int, 0),
        ("_batch_label_tokens_sum", "label_tokens_sum", int, 0),
    )
    # (batch key, _BATCH_STATS key) for string metadata; missing values become "".
    _STRING_STAT_FIELDS = (
        ("_batch_skip_reason", "skip_reason"),
        ("_batch_debug_json", "debug_json"),
    )
    # Hard limits enforced before the forward pass (see _sanitize_inputs).
    _MAX_INPUT_IDS_LEN = 32768
    _MAX_FEATURE_LEN = 300000

    def _has_batch_metadata(self, inputs) -> bool:
        """Return True if any collator ``_batch_*`` key is still in ``inputs``."""
        keys = [field[0] for field in self._TENSOR_STAT_FIELDS]
        keys += [field[0] for field in self._STRING_STAT_FIELDS]
        return any(key in inputs for key in keys)

    def _record_batch_stats(self, inputs) -> None:
        """Pop collator metadata out of ``inputs`` into ``_BATCH_STATS``.

        Missing entries reset the corresponding stat to its zero/empty default.
        """
        for in_key, stat_key, convert, default in self._TENSOR_STAT_FIELDS:
            value = inputs.pop(in_key, None)
            _BATCH_STATS[stat_key] = convert(value.item()) if value is not None else default
        for in_key, stat_key in self._STRING_STAT_FIELDS:
            _BATCH_STATS[stat_key] = inputs.pop(in_key, "") or ""

    def _prepare_inputs(self, inputs):
        # training_step() pops and records the metadata itself before deciding
        # whether to skip. Only record here when it is still present (eval and
        # prediction paths), so the training-time stats are not reset to zero.
        if self._has_batch_metadata(inputs):
            self._record_batch_stats(inputs)

        inputs = super()._prepare_inputs(inputs)
        model_dtype = getattr(self.model, "dtype", None)
        if model_dtype is not None:
            for k, v in list(inputs.items()):
                if torch.is_tensor(v) and v.is_floating_point():
                    inputs[k] = v.to(dtype=model_dtype)
        return inputs

    def _format_batch_debug(self) -> str:
        """Render the current ``_BATCH_STATS`` debug JSON as a one-line summary."""
        debug_json = _BATCH_STATS.get("debug_json") or ""
        if not debug_json:
            return ""
        try:
            debug = json.loads(debug_json)
        except Exception:
            return debug_json[:1000]

        parts = [
            f"bucket={debug.get('bucket_id') or '?'}",
            f"bs={debug.get('sample_count', '?')}",
            f"seq_len={debug.get('seq_len', '?')}",
            f"attn_sum={debug.get('attn_tokens_sum', '?')}",
            f"audio_tok_max={debug.get('audio_tokens_max', '?')}",
            f"label_tok_sum={debug.get('label_tokens_sum', '?')}",
        ]
        samples = []
        for sample in debug.get("samples", [])[:8]:
            sid = sample.get("sample_id") or sample.get("tar_member_name") or "?"
            dur = sample.get("duration_s")
            tlen = sample.get("transcript_len", "?")
            if isinstance(dur, (int, float)):
                samples.append(f"{sid}({dur:.2f}s,{tlen}c)")
            else:
                samples.append(f"{sid}({tlen}c)")
        if samples:
            parts.append("samples=[" + "; ".join(samples) + "]")
        return " ".join(parts)

    def _sanitize_inputs(self, inputs) -> str:
        """Validate inputs before they hit the GPU.

        CUDA kernel errors from malformed tensors (NaN/Inf in features, extreme
        sequence lengths) are asynchronous and unrecoverable — the only defense
        is to catch them *before* the forward pass.

        Returns an empty string if the batch is safe, otherwise a skip reason.
        """
        for k, v in inputs.items():
            if not torch.is_tensor(v) or not v.is_floating_point():
                continue
            if not torch.isfinite(v).all():
                bad_nan = int(torch.isnan(v).sum().item())
                bad_inf = int(torch.isinf(v).sum().item())
                return (
                    f"tensor '{k}' has {bad_nan} NaN, {bad_inf} Inf "
                    f"(shape={list(v.shape)})"
                )

        # Extreme sequence lengths can crash flash attention.
        input_ids = inputs.get("input_ids")
        if input_ids is not None and input_ids.shape[-1] > self._MAX_INPUT_IDS_LEN:
            return f"sequence length {input_ids.shape[-1]} > {self._MAX_INPUT_IDS_LEN}"

        feat = inputs.get("input_features")
        if feat is not None and feat.ndim >= 2 and feat.shape[-1] > self._MAX_FEATURE_LEN:
            return f"feature length {feat.shape[-1]} > {self._MAX_FEATURE_LEN}"

        return ""

    def _sync_skip_decision(self, local_skip: bool) -> Tuple[bool, int]:
        """Synchronize rank-local skip decisions across DDP ranks.

        Returns ``(skip, num_ranks_requesting_skip)``; every rank skips if any does.
        """
        local_count = 1 if local_skip else 0
        if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
            return bool(local_skip), local_count

        skip_count = torch.tensor(local_count, device=self.args.device, dtype=torch.int32)
        torch.distributed.all_reduce(skip_count, op=torch.distributed.ReduceOp.SUM)
        total = int(skip_count.item())
        return total > 0, total

    def _make_ddp_safe_zero_loss(self, model) -> torch.Tensor:
        """Build a zero-valued loss connected to every trainable parameter.

        Falls back to a standalone zero tensor when nothing is trainable.
        """
        zero_loss = None
        for param in model.parameters():
            if not param.requires_grad:
                continue
            term = param.reshape(-1)[:1].sum() * 0.0
            zero_loss = term if zero_loss is None else zero_loss + term
        if zero_loss is None:
            zero_loss = torch.tensor(0.0, device=self.args.device, requires_grad=True)
        return zero_loss

    def _maybe_set_static_graph(self, model) -> None:
        """Enable DDP static_graph once, on the first training step (after wrapping)."""
        if getattr(self, "_static_graph_set", False):
            return
        if hasattr(model, "_set_static_graph"):
            model._set_static_graph()
            if self.args.process_index == 0:
                print("[config] DDP static_graph: ON")
        self._static_graph_set = True

    @staticmethod
    def _loss_scale_to_float(loss_scale) -> Optional[float]:
        """Convert a collator ``loss_scale`` (tensor or number) to a float (mean if multi-element)."""
        if loss_scale is None:
            return None
        if torch.is_tensor(loss_scale):
            return float(loss_scale.detach().float().mean().cpu().item())
        return float(loss_scale)

    def training_step(self, model, inputs, num_items_in_batch=None):
        loss_scale_value = self._loss_scale_to_float(inputs.pop("loss_scale", None))
        # Record THIS batch's metadata now. The parent training_step is what
        # calls _prepare_inputs, so reading _BATCH_STATS before recording would
        # see the previous batch's skip reason and debug info.
        self._record_batch_stats(inputs)
        self._maybe_set_static_graph(model)

        local_skip_reason = ""
        if loss_scale_value == 0.0:
            # Zero-scale batches are the collator's dummy batches. They must be
            # skipped: loss_scale is only applied to the loss returned by the
            # parent (after backward), so it does not zero their gradients.
            local_skip_reason = _BATCH_STATS.get("skip_reason") or "loss_scale == 0"
        if not local_skip_reason:
            local_skip_reason = self._sanitize_inputs(inputs)

        skip_this_step, skip_ranks = self._sync_skip_decision(bool(local_skip_reason))
        if skip_this_step:
            rank = self.args.process_index
            step = self.state.global_step
            reason = local_skip_reason or "skip requested by another rank"
            source = "local" if local_skip_reason else "peer"
            print(
                f"[BATCH_SKIP] rank={rank} step={step} source={source} "
                f"skip_ranks={skip_ranks} reason={reason} {self._format_batch_debug()}",
                flush=True,
            )
            # All ranks reach this branch together (decision is all-reduced), so
            # no rank runs forward/backward this step.
            # NOTE(review): with gradient accumulation, skipping the sync
            # micro-step may leave earlier micro-steps' gradients un-reduced —
            # confirm against the accelerate accumulate() semantics in use.
            return self._make_ddp_safe_zero_loss(model).detach()

        try:
            loss = super().training_step(model, inputs, num_items_in_batch)
        except RuntimeError as e:
            err_msg = str(e).lower()
            if "cuda" in err_msg or "nccl" in err_msg or "device-side assert" in err_msg:
                rank = self.args.process_index
                step = self.state.global_step
                print(
                    f"[FATAL_TRAIN_ERROR] rank={rank} step={step} caught: {e} "
                    f"{self._format_batch_debug()}",
                    flush=True,
                )
            raise

        if loss_scale_value is not None:
            # Rescales only the reported loss; backward already ran in the parent.
            loss = loss * loss_scale_value
        return loss

    def _shutdown_dataloader_workers(self, dataloader):
        """Best-effort cleanup for eval DataLoader workers.

        Trainer keeps a reference to the last eval dataloader on the callback handler.
        If eval uses persistent workers, those worker processes can stay alive after
        evaluation and compete with the train workers. Recursively clean up both plain
        PyTorch dataloaders and accelerate wrappers.
        """
        if dataloader is None:
            return

        stack = [dataloader]
        seen = set()
        while stack:
            dl = stack.pop()
            if dl is None:
                continue
            dl_id = id(dl)
            if dl_id in seen:
                continue
            seen.add(dl_id)

            base_dataloader = getattr(dl, "base_dataloader", None)
            if base_dataloader is not None:
                stack.append(base_dataloader)

            iterator = getattr(dl, "_iterator", None)
            if iterator is None:
                continue

            shutdown = getattr(iterator, "_shutdown_workers", None)
            if callable(shutdown):
                try:
                    shutdown()
                except Exception:
                    pass  # best-effort: a failed shutdown must not abort training

            try:
                dl._iterator = None
            except Exception:
                pass  # best-effort: some wrappers may not allow assignment

    def _cleanup_eval_dataloaders(self):
        """Shut down workers of every eval dataloader the Trainer still references."""
        eval_loaders = []

        cb_eval_loader = getattr(self.callback_handler, "eval_dataloader", None)
        if cb_eval_loader is not None:
            eval_loaders.append(cb_eval_loader)
            self.callback_handler.eval_dataloader = None

        cached = getattr(self, "_eval_dataloaders", None)
        if cached:
            eval_loaders.extend(cached.values())
            self._eval_dataloaders = {}

        for dl in eval_loaders:
            self._shutdown_dataloader_workers(dl)

    def _make_dataloader(self, dataset, batch_size=None, is_eval: bool = False):
        """Create a plain DataLoader avoiding accelerate's dispatch_batches."""
        # BucketedIterableDataset yields whole batches, so the loader batches by 1.
        is_prebatched = isinstance(dataset, BucketedIterableDataset)
        bs = 1 if is_prebatched else (batch_size or self.args.per_device_train_batch_size)
        num_workers = self.args.dataloader_num_workers
        persistent_workers = self.args.dataloader_persistent_workers if num_workers > 0 else False
        prefetch_factor = self.args.dataloader_prefetch_factor if num_workers > 0 else None

        # Eval is small/infrequent. Keeping a separate worker pool alive for it has
        # been fragile with this IterableDataset + Trainer setup, so make eval
        # conservative and deterministic by default.
        if is_eval:
            num_workers = 0
            persistent_workers = False
            prefetch_factor = None

        return DataLoader(
            dataset,
            batch_size=bs,
            collate_fn=self.data_collator,
            num_workers=num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
        )

    def get_train_dataloader(self):
        return self._make_dataloader(self.train_dataset)

    def get_eval_dataloader(self, eval_dataset=None):
        ds = eval_dataset if eval_dataset is not None else self.eval_dataset
        if ds is None:
            raise ValueError("No eval dataset")
        self._cleanup_eval_dataloaders()
        return self._make_dataloader(ds, batch_size=self.args.per_device_eval_batch_size, is_eval=True)

    def evaluate(self, *args, **kwargs):
        try:
            return super().evaluate(*args, **kwargs)
        finally:
            self._cleanup_eval_dataloaders()


def copy_required_hf_files_for_qwen_asr(src_dir: str, dst_dir: str):
    """Copy the config/tokenizer/processor files a Qwen3-ASR checkpoint needs.

    ``dst_dir`` is created if it does not exist. Each listed file that exists
    in ``src_dir`` is copied with ``shutil.copy2`` (overwriting any file of the
    same name in ``dst_dir``); listed files missing from ``src_dir`` are skipped.
    """
    os.makedirs(dst_dir, exist_ok=True)
    for name in (
        "config.json",
        "generation_config.json",
        "preprocessor_config.json",
        "processor_config.json",
        "tokenizer_config.json",
        "tokenizer.json",
        "special_tokens_map.json",
        "chat_template.json",
        "merges.txt",
        "vocab.json",
    ):
        source = os.path.join(src_dir, name)
        if not os.path.exists(source):
            continue
        shutil.copy2(source, os.path.join(dst_dir, name))


class MakeEveryCheckpointInferableCallback(TrainerCallback):
    """Copy the base model's config/tokenizer/processor files into each saved
    checkpoint so every ``checkpoint-N`` directory can be loaded for inference.

    Only process index 0 does the copy.
    """

    def __init__(self, base_model_path: str):
        self.base_model_path = base_model_path

    def on_save(self, args: TrainingArguments, state, control, **kwargs):
        if args.process_index != 0:
            return control
        candidates = (
            os.path.join(args.output_dir, f"checkpoint-{state.global_step}"),
            kwargs.get("checkpoint"),
        )
        ckpt_dir = next(
            (
                c for c in candidates
                if isinstance(c, (str, os.PathLike)) and os.path.isdir(c)
            ),
            None,
        )
        if ckpt_dir is None:
            # Do not let the copy helper create a config-only "checkpoint-N"
            # directory: it could later be mistaken for a real checkpoint.
            print(
                f"[ckpt] checkpoint dir for step {state.global_step} not found; "
                "skipping HF file copy",
                flush=True,
            )
            return control
        copy_required_hf_files_for_qwen_asr(self.base_model_path, ckpt_dir)
        return control


class R2MilestoneCallback(TrainerCallback):
    """Push checkpoints to R2 at milestone steps (e.g., every 50k).

    On process index 0, when ``global_step`` is a multiple of
    ``milestone_interval`` and ``checkpoint-<step>`` exists, the directory is
    uploaded with ``aws s3 sync`` to ``<r2_prefix>/<MM-DD-YYYY>/ckpt-<step>/``.
    Credentials (``R2_ENDPOINT_URL``, ``R2_ACCESS_KEY_ID``,
    ``R2_SECRET_ACCESS_KEY``) are read from the ``KEY=VALUE`` file at
    ``env_path``. Upload failures (non-zero exit, timeout, missing ``aws``
    binary) are logged and never interrupt training.
    """

    def __init__(
        self,
        milestone_interval: int = 50000,
        env_path: str = "/root/data/.env",
        r2_prefix: str = "s3://ptcheckpoints/qwen3-asr-1.7B",
        timeout_s: float = 600,
    ):
        self.milestone_interval = milestone_interval
        self.env_path = env_path
        self.r2_prefix = r2_prefix.rstrip("/")
        self.timeout_s = timeout_s
        self._pushed = set()

    @staticmethod
    def _read_env_file(path: str) -> Dict[str, str]:
        """Parse ``KEY=VALUE`` lines (``#`` comments, optional ``export``/quotes)."""
        values: Dict[str, str] = {}
        if not os.path.exists(path):
            return values
        with open(path) as fh:
            for raw in fh:
                line = raw.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, value = line.split("=", 1)
                key = key.strip()
                if key.startswith("export "):
                    key = key[len("export "):].strip()
                value = value.strip()
                # Drop one pair of matching surrounding quotes.
                if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
                    value = value[1:-1]
                values[key] = value
        return values

    def on_save(self, args: TrainingArguments, state, control, **kwargs):
        if args.process_index != 0 or self.milestone_interval <= 0:
            return control
        step = state.global_step
        if step % self.milestone_interval != 0 or step in self._pushed:
            return control

        ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{step}")
        if not os.path.isdir(ckpt_dir):
            return control

        r2_vars = self._read_env_file(self.env_path)
        endpoint = r2_vars.get("R2_ENDPOINT_URL", "")
        access_key = r2_vars.get("R2_ACCESS_KEY_ID", "")
        secret_key = r2_vars.get("R2_SECRET_ACCESS_KEY", "")
        if not all([endpoint, access_key, secret_key]):
            print(f"[R2] Skipping push for step {step}: missing R2 credentials")
            return control

        import datetime
        import subprocess

        ckpt_date = datetime.datetime.now().strftime("%m-%d-%Y")
        r2_path = f"{self.r2_prefix}/{ckpt_date}/ckpt-{step}/"

        print(f"[R2] Pushing checkpoint-{step} to {r2_path} ...", flush=True)
        env = os.environ.copy()
        env["AWS_ACCESS_KEY_ID"] = access_key
        env["AWS_SECRET_ACCESS_KEY"] = secret_key
        env["AWS_DEFAULT_REGION"] = "auto"

        try:
            result = subprocess.run(
                ["aws", "s3", "sync", ckpt_dir, r2_path,
                 "--endpoint-url", endpoint, "--no-progress"],
                env=env, capture_output=True, text=True, timeout=self.timeout_s,
            )
        except (subprocess.TimeoutExpired, OSError) as exc:
            # A failed upload must not kill the training run.
            print(f"[R2] FAILED: checkpoint-{step}: {exc}", flush=True)
            return control

        if result.returncode == 0:
            self._pushed.add(step)
            print(f"[R2] SUCCESS: checkpoint-{step} pushed to {r2_path}", flush=True)
        else:
            print(f"[R2] FAILED: checkpoint-{step}: {result.stderr[:200]}", flush=True)

        return control


class ProfilingCallback(TrainerCallback):
    """Logs per-step timing and throughput metrics.

    Reports three throughput metrics:
    1. samples/sec (SPS) — raw sample count
    2. audio-sec/sec — seconds of audio processed per wall-clock second
    3. decoder-tok/sec — SPS times the fixed AVG_DECODER_TOKENS estimate

    Per-step sample count and audio seconds come from ``_BATCH_STATS`` (the
    last micro-batch seen by the trainer) and are extrapolated to the global
    step via gradient_accumulation_steps and WORLD_SIZE. When no batch size is
    recorded, per_device_train_batch_size and AVG_DURATION_S are used instead.
    """
    # Average durations per bucket for audio-sec/sec estimation
    AVG_DURATION_S = 7.11  # weighted average across all buckets
    AVG_DECODER_TOKENS = 382  # weighted average decoder tokens per sample
    # Number of most recent steps averaged in each periodic [profile] line.
    RECENT_WINDOW = 10

    def __init__(self):
        self._step_start = None
        self._step_times = []
        self._step_batch_sizes = []  # global samples per step
        self._step_audio_secs = []   # global audio seconds per step
        self._total_samples = 0
        self._total_audio_secs = 0.0
        self._train_start = None

    @staticmethod
    def _peak_mem_gb() -> float:
        """Peak CUDA memory allocated by this process, in GB (0 without CUDA)."""
        if not torch.cuda.is_available():
            return 0.0
        return torch.cuda.max_memory_allocated() / 1e9

    @staticmethod
    def _logging_interval(args, state):
        """Absolute logging interval in steps.

        Prefers ``state.logging_steps`` when the Trainer provides it (it holds
        the absolute value even when ``args.logging_steps`` is a ratio).
        """
        return getattr(state, "logging_steps", None) or args.logging_steps

    def on_train_begin(self, args, state, control, **kwargs):
        self._train_start = time.time()

    def on_step_begin(self, args, state, control, **kwargs):
        self._step_start = time.time()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Intentionally a no-op: batch stats are read from ``_BATCH_STATS``."""
        pass

    def on_step_end(self, args, state, control, **kwargs):
        if self._step_start is None:
            return
        dt = time.time() - self._step_start
        self._step_times.append(dt)

        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        grad_acc = args.gradient_accumulation_steps
        # Use actual batch size from collator if available (dynamic BS)
        actual_bs = _BATCH_STATS["batch_size"]
        actual_audio = _BATCH_STATS["audio_secs"]
        if actual_bs > 0:
            samples_this_step = actual_bs * grad_acc * world_size
            audio_secs = actual_audio * grad_acc * world_size
        else:
            samples_this_step = args.per_device_train_batch_size * grad_acc * world_size
            audio_secs = samples_this_step * self.AVG_DURATION_S

        self._step_batch_sizes.append(samples_this_step)
        self._step_audio_secs.append(audio_secs)
        self._total_samples += samples_this_step
        self._total_audio_secs += audio_secs

        interval = self._logging_interval(args, state)
        if not interval or interval <= 0 or args.process_index != 0:
            return
        if state.global_step % interval != 0:
            return

        window = self.RECENT_WINDOW
        recent_dt = self._step_times[-window:]
        recent_time = sum(recent_dt)
        if recent_time <= 0:
            return  # clock did not advance; avoid dividing by zero
        avg_step = recent_time / len(recent_dt)
        sps = sum(self._step_batch_sizes[-window:]) / recent_time
        audio_sec_per_sec = sum(self._step_audio_secs[-window:]) / recent_time
        dtok_per_sec = sps * self.AVG_DECODER_TOKENS
        elapsed = time.time() - self._train_start if self._train_start is not None else 0
        overall_sps = self._total_samples / elapsed if elapsed > 0 else 0
        print(f"[profile] step={state.global_step} dt={avg_step:.3f}s "
              f"SPS={sps:.0f} audio_sec/s={audio_sec_per_sec:.0f} "
              f"dec_tok/s={dtok_per_sec:.0f} overall_SPS={overall_sps:.0f} "
              f"peak={self._peak_mem_gb():.1f}GB")

    def on_train_end(self, args, state, control, **kwargs):
        if args.process_index != 0 or not self._step_times:
            return
        if self._train_start is not None:
            elapsed = time.time() - self._train_start
        else:
            elapsed = sum(self._step_times)
        world_size = int(os.environ.get("WORLD_SIZE", "1"))

        # Drop up to 5 warmup steps, but only when enough steps remain.
        warmup = min(5, len(self._step_times) // 3)

        def _stable(values):
            return values[warmup:] if len(values) > warmup * 2 else values

        stable_dt = _stable(self._step_times)
        stable_bs = _stable(self._step_batch_sizes)
        stable_audio = _stable(self._step_audio_secs)

        stable_avg_step = sum(stable_dt) / len(stable_dt) if stable_dt else 1.0
        total_stable_time = sum(stable_dt)
        sps = sum(stable_bs) / total_stable_time if total_stable_time > 0 else 0
        audio_sec_per_sec = sum(stable_audio) / total_stable_time if total_stable_time > 0 else 0
        dtok_per_sec = sps * self.AVG_DECODER_TOKENS

        peak_mem = self._peak_mem_gb()
        print(f"\n{'='*70}")
        print("[PROFILING REPORT]")
        print(f"  Total steps: {len(self._step_times)}")
        print(f"  Total wall time: {elapsed:.1f}s")
        print(f"  Avg step time (stable): {stable_avg_step:.3f}s")
        print("  ---")
        print(f"  Samples/sec (SPS):          {sps:.1f} global | {sps/world_size:.1f} per GPU")
        print(f"  Audio-sec/sec:              {audio_sec_per_sec:.0f} global | {audio_sec_per_sec/world_size:.0f} per GPU")
        print(f"  Decoder-tok/sec (est):      {dtok_per_sec:.0f} global | {dtok_per_sec/world_size:.0f} per GPU")
        print("  ---")
        print(f"  Peak GPU memory: {peak_mem:.1f} GB")
        print(f"  World size: {world_size}")
        print(f"{'='*70}\n")


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser for this finetuning script and parse it.

    Args:
        argv: Explicit argument list, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly like before.

    Returns:
        The parsed ``argparse.Namespace``. On/off switches are declared as
        ``type=int`` (0 or 1) rather than ``store_true`` flags.
    """
    # The title is a description; passing it positionally would make argparse
    # treat it as ``prog`` and print it in place of the program name in usage.
    p = argparse.ArgumentParser(description="Qwen3-ASR Finetuning on Phase2 parquet+tar dataset")

    # Paths
    p.add_argument("--model_path", type=str, default="Qwen/Qwen3-ASR-1.7B")
    p.add_argument("--train_file", type=str, default="")
    p.add_argument("--eval_file", type=str, default="")
    p.add_argument("--output_dir", type=str, default="./qwen3-asr-phase2-out")

    # Bucketed mode
    p.add_argument("--bucket_dir", type=str, default="",
                   help="Path to bucket parquet directory. If set, uses bucketed dataset.")
    p.add_argument("--bucket_config", type=str, default="",
                   help="Path to bucket_config.json")

    # Split and prompt
    p.add_argument("--train_split", type=str, default="train")
    p.add_argument("--eval_split", type=str, default="dev")
    p.add_argument("--system_prompt", type=str, default="")

    # Audio/text formatting
    p.add_argument("--sr", type=int, default=16000)
    p.add_argument("--language_tag_mode", type=str, default="auto", choices=["auto", "none", "code", "name"])
    p.add_argument("--transcript_has_prefix", type=int, default=0)
    p.add_argument("--skip_empty_transcript", type=int, default=1)
    p.add_argument("--skip_decode_errors", type=int, default=1)
    p.add_argument("--min_duration_s", type=float, default=0.0)
    p.add_argument("--max_duration_s", type=float, default=1200.0)

    # Dataset iteration
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--manifest_batch_rows", type=int, default=65536)
    p.add_argument("--shuffle_row_groups", type=int, default=1)
    p.add_argument("--shuffle_within_batch", type=int, default=1)
    p.add_argument("--max_samples", type=int, default=0)
    p.add_argument("--eval_max_samples", type=int, default=0)
    p.add_argument("--max_open_tars", type=int, default=64)
    p.add_argument("--use_indexed_tar", type=int, default=1, help="Use IndexedTarReader (requires .index.json files)")
    p.add_argument("--worker_decode", type=int, default=1, help="Decode audio in workers (1) or collator (0)")

    # Train hyper-params
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--grad_acc", type=int, default=4)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--epochs", type=float, default=1.0)
    p.add_argument("--max_steps", type=int, default=0)
    p.add_argument("--log_steps", type=int, default=10)
    p.add_argument("--lr_scheduler_type", type=str, default="linear")
    p.add_argument("--warmup_ratio", type=float, default=0.02)

    # DataLoader
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--pin_memory", type=int, default=1)
    p.add_argument("--persistent_workers", type=int, default=1)
    p.add_argument("--prefetch_factor", type=int, default=2)

    # Save/eval
    p.add_argument("--save_strategy", type=str, default="steps")
    p.add_argument("--save_steps", type=int, default=200)
    p.add_argument("--save_total_limit", type=int, default=5)
    p.add_argument("--eval_steps", type=int, default=0)

    # Performance
    p.add_argument("--torch_compile", type=int, default=0)
    p.add_argument("--compile_mode", type=str, default="reduce-overhead",
                   choices=["default", "reduce-overhead", "max-autotune"])
    p.add_argument("--gradient_checkpointing", type=int, default=0)
    # NOTE(review): not read by main(); confirm whether another part of the file uses it.
    p.add_argument("--ddp_static_graph", type=int, default=1)
    p.add_argument("--profiling", type=int, default=0, help="Enable detailed profiling callback")
    p.add_argument("--wandb_project", type=str, default="", help="W&B project name (enables wandb logging)")
    p.add_argument("--wandb_run_name", type=str, default="", help="W&B run name")
    p.add_argument("--attn_implementation", type=str, default="flash_attention_2",
                   choices=["sdpa", "flash_attention_2", "eager"],
                   help="Attention implementation (flash_attention_2 recommended for H200)")
    p.add_argument("--bucket_bs", type=str, default="",
                   help="JSON string or path to per-bucket batch sizes, e.g. '{\"b_0_3\":64,\"b_7_10\":24}'")
    # NOTE(review): not read by main(); confirm whether another part of the file uses it.
    p.add_argument("--vram_target_pct", type=float, default=0.85,
                   help="Target vRAM utilization for dynamic BS (0.0-1.0)")
    p.add_argument("--max_batch_seq_len", type=int, default=700,
                   help="Skip tokenized batches whose padded multimodal sequence length exceeds this limit (0 disables)")

    # Resume
    p.add_argument("--ignore_data_skip", type=int, default=0,
                   help="When resuming, do not fast-forward the IterableDataset to the exact saved batch")
    p.add_argument("--resume_from", type=str, default="")
    p.add_argument("--resume", type=int, default=0)
    return p.parse_args(argv)


def auto_max_steps(estimated_rows: int, batch_size: float, grad_acc: int, epochs: float) -> int:
    """Estimate how many optimizer steps cover ``epochs`` passes over the data.

    One step consumes ``batch_size * grad_acc`` samples on each of the
    ``WORLD_SIZE`` processes (1 when the variable is unset). Negative epochs
    are treated as 0. The per-step sample count is floored at 1. The result
    is rounded up and is never below 1.
    """
    n_procs = int(os.environ.get("WORLD_SIZE", "1"))
    samples_per_step = max(1.0, float(batch_size) * grad_acc * n_procs)
    total_samples = estimated_rows * max(epochs, 0.0)
    steps = int(math.ceil(total_samples / samples_per_step))
    return steps if steps > 1 else 1


def main():
    """Fine-tune Qwen3-ASR with the HF ``Trainer`` using options from ``parse_args``.

    Flow:
      1. Parse CLI options. Export ``WANDB_PROJECT`` when ``--wandb_project`` is set.
      2. Pick the mode. Bucketed mode needs both ``--bucket_dir`` and
         ``--bucket_config``; otherwise the legacy single-parquet mode
         (``--train_file``) is used.
      3. Load the model: bf16 when CUDA device 0 reports compute capability
         >= 8, else fp16. Patch its forward, and optionally enable gradient
         checkpointing and/or ``torch.compile``.
      4. Build the train dataset and optional eval dataset. Derive
         ``max_steps`` when ``--max_steps <= 0``. Assemble
         ``TrainingArguments``, the collator and the callbacks.
      5. Train. Resume from ``--resume_from`` if given; otherwise, with
         ``--resume 1``, resume from whatever
         ``find_latest_checkpoint(output_dir)`` returns.

    Raises:
        ValueError: legacy mode selected but ``--train_file`` is empty.
        FileNotFoundError: ``--train_file`` does not exist.
    """
    args_cli = parse_args()
    rank = int(os.environ.get("RANK", "0"))

    # Set wandb project via env var (HF Trainer reads this)
    if args_cli.wandb_project:
        os.environ["WANDB_PROJECT"] = args_cli.wandb_project
        if rank == 0:
            print(f"[config] wandb: project={args_cli.wandb_project}, run={args_cli.wandb_run_name or 'auto'}")

    # Determine mode: bucketed (requires BOTH bucket args) or legacy single parquet.
    use_buckets = bool(args_cli.bucket_dir and args_cli.bucket_config)

    if not use_buckets:
        if not args_cli.train_file:
            raise ValueError("Either --train_file or --bucket_dir + --bucket_config must be provided")
        if not os.path.exists(args_cli.train_file):
            raise FileNotFoundError(f"train_file not found: {args_cli.train_file}")

    # bf16 only when device 0 reports compute capability >= 8; fp16 otherwise.
    use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    attn_impl = args_cli.attn_implementation
    if rank == 0:
        print(f"[config] attn_implementation: {attn_impl}")
    asr_wrapper = Qwen3ASRModel.from_pretrained(
        args_cli.model_path,
        dtype=torch.bfloat16 if use_bf16 else torch.float16,
        device_map=None,
        attn_implementation=attn_impl,
    )
    model = asr_wrapper.model
    processor = asr_wrapper.processor

    patch_outer_forward(model)
    model.generation_config = GenerationConfig.from_model_config(model.config)

    if args_cli.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        if rank == 0:
            print("[config] gradient_checkpointing: ON")
    else:
        if rank == 0:
            print("[config] gradient_checkpointing: OFF (max speed)")

    if args_cli.torch_compile:
        if rank == 0:
            print(f"[config] torch.compile(mode='{args_cli.compile_mode}')")
        model = torch.compile(model, mode=args_cli.compile_mode)

    worker_decode = (args_cli.worker_decode == 1)
    use_indexed = (args_cli.use_indexed_tar == 1)

    if use_buckets:
        with open(args_cli.bucket_config, "r") as f:
            bucket_config = json.load(f)

        # Per-bucket batch sizes: --bucket_bs is either a path to a JSON file
        # or an inline JSON string.
        bucket_bs = None
        if args_cli.bucket_bs:
            bs_str = args_cli.bucket_bs.strip()
            if os.path.exists(bs_str):
                with open(bs_str) as f:
                    bucket_bs = json.load(f)
            else:
                bucket_bs = json.loads(bs_str)

        if rank == 0:
            # `total_hours` may be missing (or null) in the config. Formatting
            # the "?" placeholder with ":.0f" would raise, so only apply the
            # numeric format to real numbers.
            total_hours = bucket_config.get("total_hours")
            hours_str = (f"{total_hours:.0f}h"
                         if isinstance(total_hours, (int, float)) else "?h")
            print(f"[config] BUCKETED mode: {len(bucket_config['buckets'])} buckets, "
                  f"{bucket_config.get('total_samples', '?')} samples, "
                  f"{hours_str}")
            print(f"[config] worker_decode={worker_decode}, indexed_tar={use_indexed}")

        train_ds = BucketedIterableDataset(
            bucket_dir=args_cli.bucket_dir,
            bucket_config=bucket_config,
            prompt=args_cli.system_prompt,
            seed=args_cli.seed,
            manifest_batch_rows=args_cli.manifest_batch_rows,
            shuffle_row_groups=(args_cli.shuffle_row_groups == 1),
            shuffle_within_batch=(args_cli.shuffle_within_batch == 1),
            skip_empty_transcript=(args_cli.skip_empty_transcript == 1),
            max_open_tars=args_cli.max_open_tars,
            use_indexed_tar=use_indexed,
            target_sr=args_cli.sr,
            bucket_bs=bucket_bs,
            default_batch_size=args_cli.batch_size,
        )

        if rank == 0:
            print("[config] Per-bucket batch sizes:")
            for b in train_ds.buckets:
                print(f"  {b['id']}: BS={b['batch_size']} ({b['samples']:,} samples, {b['weight']:.0f}h)")
    else:
        if rank == 0:
            print(f"[config] LEGACY mode (single parquet): {args_cli.train_file}")
            print(f"[config] worker_decode={worker_decode}, indexed_tar={use_indexed}")

        train_ds = ParquetTarIterableDataset(
            parquet_path=args_cli.train_file,
            split=args_cli.train_split,
            prompt=args_cli.system_prompt,
            seed=args_cli.seed,
            manifest_batch_rows=args_cli.manifest_batch_rows,
            shuffle_row_groups=(args_cli.shuffle_row_groups == 1),
            shuffle_within_batch=(args_cli.shuffle_within_batch == 1),
            min_duration_s=args_cli.min_duration_s,
            max_duration_s=args_cli.max_duration_s,
            skip_empty_transcript=(args_cli.skip_empty_transcript == 1),
            max_samples=args_cli.max_samples,
            max_open_tars=args_cli.max_open_tars,
            use_indexed_tar=use_indexed,
            target_sr=args_cli.sr,
            worker_decode=worker_decode,
        )

    # Evaluation data is read in file order (no shuffling) from --eval_file.
    eval_ds = None
    if args_cli.eval_file:
        eval_ds = ParquetTarIterableDataset(
            parquet_path=args_cli.eval_file,
            split=args_cli.eval_split,
            prompt=args_cli.system_prompt,
            seed=args_cli.seed + 1,
            manifest_batch_rows=args_cli.manifest_batch_rows,
            shuffle_row_groups=False,
            shuffle_within_batch=False,
            min_duration_s=args_cli.min_duration_s,
            max_duration_s=args_cli.max_duration_s,
            skip_empty_transcript=(args_cli.skip_empty_transcript == 1),
            max_samples=args_cli.eval_max_samples,
            max_open_tars=args_cli.max_open_tars,
            use_indexed_tar=use_indexed,
            target_sr=args_cli.sr,
            worker_decode=worker_decode,
        )

    max_steps = int(args_cli.max_steps)
    if max_steps <= 0:
        est_rows = train_ds.estimated_num_rows
        if args_cli.max_samples > 0:
            est_rows = min(est_rows, args_cli.max_samples)
        # In pre-batched mode: each step = 1 pre-formed batch of weighted_avg_bs samples per GPU
        # In uniform mode: each step = batch_size * grad_acc samples per GPU
        if use_buckets and hasattr(train_ds, "buckets"):
            weighted_bs = sum(
                float(b["batch_size"]) * float(b.get("norm_weight", 0.0))
                for b in train_ds.buckets
            )
            effective_bs = max(1.0, weighted_bs)
            if rank == 0:
                print(f"[auto] effective dynamic BS={effective_bs:.1f} (weighted by bucket hours)")
            max_steps = auto_max_steps(
                estimated_rows=est_rows,
                batch_size=effective_bs,
                grad_acc=1,  # no grad acc in pre-batched mode
                epochs=args_cli.epochs,
            )
        else:
            max_steps = auto_max_steps(
                estimated_rows=est_rows,
                batch_size=args_cli.batch_size,
                grad_acc=args_cli.grad_acc,
                epochs=args_cli.epochs,
            )
        if rank == 0:
            print(f"[auto] max_steps={max_steps} (estimated_rows={est_rows})")

    eval_steps = args_cli.eval_steps if args_cli.eval_steps > 0 else args_cli.save_steps
    do_eval = bool(eval_ds is not None)

    # In pre-batched mode, DataLoader batch_size=1 yields one pre-formed batch per step.
    # Trainer must use per_device_train_batch_size=1 so each optimizer step = 1 pre-formed batch.
    trainer_bs = 1 if use_buckets else args_cli.batch_size
    trainer_grad_acc = 1 if use_buckets else args_cli.grad_acc

    # Worker-only DataLoader options are only requested when worker processes
    # exist. torch's DataLoader rejects persistent_workers=True with
    # num_workers == 0; prefetch_factor must be None in that case as well.
    has_workers = args_cli.num_workers > 0

    training_args = TrainingArguments(
        output_dir=args_cli.output_dir,
        per_device_train_batch_size=trainer_bs,
        gradient_accumulation_steps=trainer_grad_acc,
        learning_rate=args_cli.lr,
        num_train_epochs=args_cli.epochs,
        max_steps=max_steps,
        logging_steps=args_cli.log_steps,
        lr_scheduler_type=args_cli.lr_scheduler_type,
        warmup_ratio=args_cli.warmup_ratio,
        dataloader_num_workers=args_cli.num_workers,
        dataloader_pin_memory=(args_cli.pin_memory == 1),
        dataloader_persistent_workers=(args_cli.persistent_workers == 1 and has_workers),
        dataloader_prefetch_factor=args_cli.prefetch_factor if has_workers else None,
        save_strategy=args_cli.save_strategy,
        save_steps=args_cli.save_steps,
        save_total_limit=args_cli.save_total_limit,
        save_safetensors=True,
        eval_strategy="steps" if do_eval else "no",
        eval_steps=eval_steps,
        do_eval=do_eval,
        bf16=use_bf16,
        fp16=not use_bf16,
        ddp_find_unused_parameters=False,
        remove_unused_columns=False,
        report_to="wandb" if args_cli.wandb_project else "none",
        run_name=args_cli.wandb_run_name or None,
        # Compilation (if requested) was already applied to the model above.
        torch_compile=False,
        ignore_data_skip=(args_cli.ignore_data_skip == 1),
    )

    collator = DataCollatorForQwen3ASRPhase2(
        processor=processor,
        sampling_rate=args_cli.sr,
        max_open_tars=args_cli.max_open_tars,
        language_tag_mode=args_cli.language_tag_mode,
        transcript_has_prefix=(args_cli.transcript_has_prefix == 1),
        skip_decode_errors=(args_cli.skip_decode_errors == 1),
        worker_decode=worker_decode,
        max_batch_seq_len=args_cli.max_batch_seq_len,
    )

    callbacks = [MakeEveryCheckpointInferableCallback(base_model_path=args_cli.model_path)]
    callbacks.append(R2MilestoneCallback(milestone_interval=50000))
    if args_cli.profiling:
        callbacks.append(ProfilingCallback())

    trainer = CastFloatInputsTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=collator,
        tokenizer=processor.tokenizer,
        callbacks=callbacks,
    )

    # An explicit --resume_from wins; otherwise --resume 1 asks
    # find_latest_checkpoint() to locate one in output_dir.
    resume_from = (args_cli.resume_from or "").strip()
    if not resume_from and args_cli.resume == 1:
        resume_from = find_latest_checkpoint(training_args.output_dir) or ""

    if resume_from:
        if trainer.args.process_index == 0:
            print(f"[resume] resume_from_checkpoint = {resume_from}")
            print(f"[resume] ignore_data_skip = {int(training_args.ignore_data_skip)}")
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
