"""Calibrate per-bucket batch sizes by running short trials.

For each bucket, binary-searches the batch size (calibrate mode) or tests a
fixed initial guess (verify mode), running a few short training steps per
candidate. Reports the max BS whose peak memory stays under MAX_MEM_GB.

Usage:
    # Single GPU quick calibration
    CUDA_VISIBLE_DEVICES=0 python3 tools/calibrate_bucket_bs.py

    # Multi-GPU verification
    torchrun --nproc_per_node=8 tools/calibrate_bucket_bs.py --mode verify
"""
import os, json
os.environ["OMP_NUM_THREADS"] = "1"

import torch
from pathlib import Path

BUCKET_DIR = Path("/root/gemini-asr/lf_asr/artifacts/phase2/buckets")
MODEL_PATH = "/root/data/qwen3_asr_weights"
MAX_MEM_GB = 130  # Leave 13GB headroom on 143GB H200
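# Note: the budget is compared against torch.cuda.max_memory_allocated(),
# which counts tensor allocations only; the allocator's cache and the CUDA
# context are excluded, so keep some slack beyond this number.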

# Initial BS estimates per bucket (conservative)
INITIAL_BS = {
    "b_0_3":  96,
    "b_3_5":  64,
    "b_5_7":  48,
    "b_7_10": 32,
    "b_10_15": 24,
    "b_15_20": 16,
    "b_20_30": 8,
}
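
# A downstream training script might read the calibrated sizes back like this
# (hypothetical consumer; only the JSON written below is defined here):
#
#   with open(BUCKET_DIR / "calibrated_bs.json") as f:
#       calibrated = json.load(f)
#   bs = calibrated.get(bucket_id, {}).get("batch_size", INITIAL_BS.get(bucket_id, 16))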


def test_batch_size(model, processor, bucket_id, parquet_path, batch_size, steps=3):
    """Run `steps` training steps at the given BS.

    Returns (peak_mem_gb, status); peak_mem_gb is None on OOM or error.
    """
    import pyarrow.parquet as pq

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    try:
        # Load a chunk of data from this bucket
        df = pq.read_table(parquet_path, columns=["tar_path", "tar_member_name", "transcript",
                                                    "language", "duration_s"]).to_pandas()
        df = df.head(batch_size * steps * 2)  # Get enough rows

        from finetuning.qwen3_asr_sft_phase2 import (
            IndexedTarReader, _decode_audio_bytes, clean_transcript,
            detect_script_language, format_target_text, build_prefix_messages,
            _get_feat_extract_output_lengths,
        )

        tar_reader = IndexedTarReader(max_open_files=32)

        # Build batches
        audios = []
        for _, row in df.iterrows():
            try:
                raw = tar_reader.read_member(row["tar_path"], row["tar_member_name"])
                wav = _decode_audio_bytes(raw, 16000)
                if wav is not None and len(wav) > 0:
                    audios.append({
                        "waveform": wav,
                        "transcript": clean_transcript(str(row["transcript"])),
                        "language": str(row["language"]),
                    })
            except Exception:
                continue
            if len(audios) >= batch_size * steps:
                break

        tar_reader.close()

        if len(audios) < batch_size:
            return None, "not enough samples"

        # Run forward+backward for a few steps. A fresh AdamW is created here
        # so that its lazily-allocated state is included in the peak-memory
        # measurement.
        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

        prefix_msgs = build_prefix_messages("", None)
        prefix_text = processor.apply_chat_template(
            [prefix_msgs], add_generation_prompt=True, tokenize=False
        )[0]
        eos = processor.tokenizer.eos_token or ""

        for step in range(steps):
            batch = audios[step * batch_size:(step + 1) * batch_size]
            if len(batch) < batch_size:
                break

            wavs = [s["waveform"] for s in batch]
            texts = []
            for s in batch:
                lang = detect_script_language(s["transcript"], s["language"])
                target = format_target_text(s["transcript"], lang, "auto", False)
                texts.append(prefix_text + target + eos)

            inputs = processor(
                text=texts, audio=wavs, return_tensors="pt",
                padding=True, truncation=False,
            )

            # Build labels: the audio placeholder expands to a variable number
            # of tokens per sample, so recompute each tokenized prefix length
            # and mask the prefix (and padding) to -100 so the loss covers only
            # the target text.
            audio_token_lens = _get_feat_extract_output_lengths(
                inputs["feature_attention_mask"].sum(dim=1)
            ).tolist()
            prefix_expanded = processor.replace_multimodal_special_tokens(
                [prefix_text] * len(wavs), iter(audio_token_lens),
            )
            prefix_tok = processor.tokenizer(
                prefix_expanded, return_tensors="pt", padding=True, truncation=False,
            )
            prefix_lens = prefix_tok["attention_mask"].sum(dim=1).tolist()

            labels = inputs["input_ids"].clone()
            for i, pl in enumerate(prefix_lens):
                valid = torch.nonzero(inputs["attention_mask"][i], as_tuple=True)[0]
                labels[i, valid[:pl]] = -100
            pad_id = processor.tokenizer.pad_token_id
            if pad_id is not None:
                labels[labels == pad_id] = -100

            # Move to GPU
            device = next(model.parameters()).device
            dtype = next(model.parameters()).dtype
            gpu_inputs = {}
            for k, v in inputs.items():
                if torch.is_tensor(v):
                    v = v.to(device)
                    if v.is_floating_point():
                        v = v.to(dtype=dtype)
                gpu_inputs[k] = v
            gpu_inputs["labels"] = labels.to(device)

            outputs = model(**gpu_inputs)
            loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0]
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        peak_mem = torch.cuda.max_memory_allocated() / 1e9  # decimal GB, same unit as MAX_MEM_GB
        return peak_mem, "ok"

    except torch.cuda.OutOfMemoryError:
        model.zero_grad(set_to_none=True)  # free gradients left by the failed step
        torch.cuda.empty_cache()
        return None, "OOM"
    except Exception as e:
        model.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        return None, f"error: {e}"
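

# Hypothetical helper (not called by this script): derive per-bucket gradient
# accumulation from the calibrated batch sizes so every bucket reaches roughly
# the same effective batch size. Assumes only the "batch_size" field that this
# script writes to calibrated_bs.json; target_global_bs and world_size are
# caller-chosen.
def suggest_grad_acc(calibrated, target_global_bs, world_size=1):
    """Return {bucket_id: grad_acc} with bs * grad_acc * world_size ~= target."""
    out = {}
    for bid, entry in calibrated.items():
        bs = max(1, int(entry["batch_size"]))
        out[bid] = max(1, round(target_global_bs / (bs * world_size)))
    return out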


def main():
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument("--mode", default="calibrate", choices=["calibrate", "verify"])
    p.add_argument("--gc", type=int, default=1, help="Gradient checkpointing")
    args = p.parse_args()

    # Load bucket config
    with open(BUCKET_DIR / "bucket_config.json") as f:
        config = json.load(f)

    print("Loading model...")
    from qwen_asr import Qwen3ASRModel
    from finetuning.qwen3_asr_sft_phase2 import patch_outer_forward
    from transformers import GenerationConfig

    use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    asr = Qwen3ASRModel.from_pretrained(
        MODEL_PATH, dtype=torch.bfloat16 if use_bf16 else torch.float16, device_map=None,
    )
    model = asr.model
    processor = asr.processor
    patch_outer_forward(model)
    model.generation_config = GenerationConfig.from_model_config(model.config)

    if args.gc:
        model.gradient_checkpointing_enable()
        print("[config] gradient_checkpointing: ON")

    model = model.to("cuda")

    results = {}
    for b in config["buckets"]:
        bid = b["bucket_id"]
        parquet_path = str(BUCKET_DIR / f"{bid}.parquet")
        if not os.path.exists(parquet_path):
            print(f"  SKIP {bid}: parquet not found")
            continue

        initial_bs = INITIAL_BS.get(bid, 16)

        if args.mode == "calibrate":
            # Binary search for max BS
            lo, hi = 8, initial_bs * 2
            best_bs = 8
            best_mem = None

            print(f"\n--- {bid} (range: {b.get('range', '?')}, samples: {b['samples']:,}) ---")

            while lo <= hi:
                mid = (lo + hi) // 2
                # Round to multiple of 8
                mid = max(8, (mid // 8) * 8)

                print(f"  Testing BS={mid}...", end=" ", flush=True)
                peak_mem, status = test_batch_size(model, processor, bid, parquet_path, mid, steps=3)

                if peak_mem is not None and peak_mem < MAX_MEM_GB:
                    print(f"OK (peak={peak_mem:.1f}GB)")
                    best_bs = mid
                    best_mem = peak_mem
                    lo = mid + 8
                else:
                    if peak_mem is not None:
                        print(f"over limit (peak={peak_mem:.1f}GB)")
                    else:
                        print(status)
                    hi = mid - 8

                torch.cuda.empty_cache()

            results[bid] = {"batch_size": best_bs, "peak_mem_gb": best_mem, "grad_acc": 1}
            print(f"  -> Best BS for {bid}: {best_bs} (peak={best_mem:.1f}GB)" if best_mem else f"  -> Best BS for {bid}: {best_bs}")

        else:
            # Verify mode: just test the initial BS
            print(f"\n--- Verifying {bid} BS={initial_bs} ---")
            peak_mem, status = test_batch_size(model, processor, bid, parquet_path, initial_bs, steps=3)
            if peak_mem is not None:
                verdict = "OK" if peak_mem < MAX_MEM_GB else "OVER LIMIT"
                print(f"  {bid}: BS={initial_bs} peak={peak_mem:.1f}GB - {verdict}")
                results[bid] = {"batch_size": initial_bs, "peak_mem_gb": peak_mem, "grad_acc": 1}
            else:
                print(f"  {bid}: BS={initial_bs} - {status}")
                results[bid] = {"batch_size": initial_bs // 2, "peak_mem_gb": None, "grad_acc": 1, "note": status}
            torch.cuda.empty_cache()

    # Save results (rank 0 only, in case this is running under torchrun)
    if int(os.environ.get("RANK", "0")) == 0:
        out_path = BUCKET_DIR / "calibrated_bs.json"
        with open(out_path, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nCalibration results saved to: {out_path}")
        print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()
