"""SFT shard worker: processes pre-segmented audio shards into XCodec2 tokens.

SFT data arrives as pre-segmented FLAC files packed into tar shards, so no VAD pass is needed. Flow:

  1. Claim shard from Supabase (atomic, FOR UPDATE SKIP LOCKED)
  2. Download audio.tar from R2 (multi-threaded range GET)
  3. Extract FLACs in memory → decode to 16kHz waveforms
  4. XCodec2 encode (same 6s overlap chunking + center-cut stitch as pretraining)
  5. Upload xcodec2_tokens.parquet back to the same R2 shard folder
  6. Mark shard DONE in Supabase

Uses the custom 198k-step fine-tuned XCodec2 checkpoint (downloaded from R2
xcodec bucket on first run), NOT the base HuggingFace model.

Resilience features:
  - Prefetch: next shard downloads in background while current one encodes.
  - Checkpointing: every CKPT_INTERVAL segments, flush to local parquet.
    On OOM/crash/restart, resume from last checkpoint — only re-encode
    the segments after the checkpoint, not the whole shard.
  - Self-recovery: OOM triggers CUDA cache clear + resume from checkpoint.
    Transient errors retry with backoff, not full shard restart.
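
Reading the tokens back (a minimal sketch; column layout matches _upload_tokens
below, byte order is the writer machine's native order, and the parquet is
assumed to have been downloaded locally):

    import numpy as np
    import pyarrow.parquet as pq

    table = pq.read_table("xcodec2_tokens.parquet")
    raw = table.column("xcodec2_tokens")[0].as_py()   # raw uint16 bytes
    tokens = np.frombuffer(raw, dtype=np.uint16)      # 50 tokens per second of audio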
"""

from __future__ import annotations

import gc
import hashlib
import io
import logging
import os
import signal
import tarfile
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torchaudio

from codecbench.pipeline.config import PipelineConfig
from codecbench.pipeline.encoder import HotEncoder
from codecbench.pipeline.vad import Segment
from codecbench.pipeline.sft_supabase import SFTOrchestrator

logger = logging.getLogger(__name__)

TOKENS_PER_SEC = 50  # 16000 / 320
CKPT_INTERVAL = 1000  # flush checkpoint every N segments


@dataclass
class ShardResult:
    """Encoding results for one shard."""
    shard_key: str
    segment_ids: list[str]
    xcodec2_tokens: list[np.ndarray]
    token_counts: list[int]
    total_audio_s: float
    total_encode_ms: float
    segments_encoded: int
    segments_failed: int


@dataclass
class PreparedShard:
    """A shard whose tar has been downloaded and may later be decoded."""
    shard_row: dict
    shard_key: str
    tar_bytes: bytes
    decoded: list[tuple[str, Segment]] | None = None
    download_s: float = 0.0
    decode_s: float = 0.0


# ── Checkpoint helpers ───────────────────────────────────────────────

def _ckpt_path(tmp_dir: str, shard_key: str) -> Path:
    """Deterministic local checkpoint path for a shard."""
    h = hashlib.md5(shard_key.encode()).hexdigest()[:12]
    return Path(tmp_dir) / f"sft_ckpt_{h}.parquet"


def _save_checkpoint(
    path: Path,
    segment_ids: list[str],
    tokens: list[np.ndarray],
    token_counts: list[int],
) -> None:
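    # Tokens are serialized as raw native-endian uint16 bytes; _load_checkpoint
    # mirrors this with np.frombuffer(dtype=np.uint16).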
    token_bytes = [t.tobytes() for t in tokens]
    table = pa.table({
        "segment_id": pa.array(segment_ids, type=pa.string()),
        "xcodec2_tokens": pa.array(token_bytes, type=pa.binary()),
        "token_count": pa.array(token_counts, type=pa.int32()),
    })
    path.parent.mkdir(parents=True, exist_ok=True)
    pq.write_table(table, path, compression="zstd")
    logger.info("Checkpoint saved: %d segments → %s (%.1f MB)",
                len(segment_ids), path.name, path.stat().st_size / 1e6)


def _load_checkpoint(path: Path) -> tuple[set[str], list[str], list[np.ndarray], list[int]]:
    """Load checkpoint. Returns (done_ids_set, segment_ids, tokens, counts)."""
    table = pq.read_table(path)
    seg_ids = table.column("segment_id").to_pylist()
    counts = table.column("token_count").to_pylist()
    raw_bytes = table.column("xcodec2_tokens").to_pylist()
    tokens = [np.frombuffer(b, dtype=np.uint16).copy() for b in raw_bytes]
    logger.info("Checkpoint loaded: %d segments from %s", len(seg_ids), path.name)
    return set(seg_ids), seg_ids, tokens, counts


# ── Worker ───────────────────────────────────────────────────────────

class SFTWorker:
    """Processes SFT audio shards: R2 tar → XCodec2 tokens → R2 upload."""

    def __init__(self, cfg: PipelineConfig):
        self.cfg = cfg
        self.encoder = HotEncoder(cfg.codec)
        self.orch: SFTOrchestrator | None = None
        self._s3 = None
        self._dl_workers = 4
        self._running = False
        self._total_shards = 0
        self._total_audio_s = 0.0
        self._total_segments = 0
        self._consecutive_errors = 0

        if not cfg.worker.worker_id:
            gpu_name = "unknown"
            if torch.cuda.is_available():
                gpu_name = torch.cuda.get_device_name(0).replace(" ", "_")
            cfg.worker.gpu_name = gpu_name
            cfg.worker.worker_id = (
                f"sft_{cfg.worker.offer_id}_{gpu_name}_{os.getpid()}"
            )

    def setup(self) -> None:
        """One-time setup: DB, custom checkpoint download, model loading."""
        import boto3
        from botocore.config import Config as BotoConfig
        from dotenv import load_dotenv
        load_dotenv()

        logger.info("=== SFT Worker Setup ===")
        logger.info("Worker ID: %s", self.cfg.worker.worker_id)

        cpus = os.cpu_count() or 4
        self._dl_workers = max(cpus * 3 // 4, 4)

        self._s3 = boto3.client(
            "s3",
            endpoint_url=os.environ["R2_ENDPOINT_URL"],
            aws_access_key_id=os.environ["R2_ACCESS_KEY_ID"],
            aws_secret_access_key=os.environ["R2_SECRET_ACCESS_KEY"],
            region_name="auto",
            config=BotoConfig(
                max_pool_connections=self._dl_workers + 4,
                retries={"max_attempts": 5, "mode": "adaptive"},
            ),
        )

        self.orch = SFTOrchestrator()
        self.orch.ensure_tables()
        self.orch.create_claim_rpc()

        # Custom 198k-step XCodec2 checkpoint from R2
        if self.cfg.r2.xcodec_ckpt_key and not self.cfg.codec.xcodec2_custom_ckpt:
            tmp_dir = Path(self.cfg.worker.local_tmp_dir)
            tmp_dir.mkdir(parents=True, exist_ok=True)
            ckpt_local = tmp_dir / "xcodec2_custom.ckpt"
            if not ckpt_local.exists():
                logger.info(
                    "Downloading custom XCodec2 checkpoint: %s/%s",
                    self.cfg.r2.xcodec_bucket, self.cfg.r2.xcodec_ckpt_key,
                )
                self._s3.download_file(
                    self.cfg.r2.xcodec_bucket,
                    self.cfg.r2.xcodec_ckpt_key,
                    str(ckpt_local),
                )
                logger.info("Custom checkpoint: %.1f MB", ckpt_local.stat().st_size / 1e6)
            else:
                logger.info("Custom checkpoint cached: %s", ckpt_local)
            self.cfg.codec.xcodec2_custom_ckpt = str(ckpt_local)

        logger.info("Loading XCodec2 (custom_ckpt=%s)...", bool(self.cfg.codec.xcodec2_custom_ckpt))
        self.encoder.load()

        vram = torch.cuda.memory_allocated() / 1e6
        logger.info(
            "=== Setup Complete. VRAM: %.0f MB | vCPUs: %d | dl_workers: %d ===",
            vram, cpus, self._dl_workers,
        )

    # ── Download + decode ────────────────────────────────────────────

    def _download_tar(self, shard_key: str) -> bytes:
        """Download audio.tar with parallel range GETs scaled to instance vCPUs."""
        bucket = os.environ.get("R2_BUCKET_DESTINATION", "finalsftdata")
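        # shard_key is assumed to end with "/": object keys here and in
        # _upload_tokens are built by plain concatenation.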
        tar_key = f"{shard_key}audio.tar"

        t0 = time.perf_counter()
        head = self._s3.head_object(Bucket=bucket, Key=tar_key)
        total_size = head["ContentLength"]

        chunk_size = 64 * 1024 * 1024
        if total_size > 50 * 1024 * 1024:
            n_parts = (total_size + chunk_size - 1) // chunk_size
            parts: list[bytes | None] = [None] * n_parts

            def _dl_part(idx: int):
                start = idx * chunk_size
                end = min(start + chunk_size - 1, total_size - 1)
                resp = self._s3.get_object(
                    Bucket=bucket, Key=tar_key,
                    Range=f"bytes={start}-{end}",
                )
                parts[idx] = resp["Body"].read()

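            # list() drains the lazy map iterator so that any exception raised
            # in a range-GET worker propagates here instead of being dropped.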
            with ThreadPoolExecutor(max_workers=self._dl_workers) as pool:
                list(pool.map(_dl_part, range(n_parts)))
            body = b"".join(parts)
        else:
            resp = self._s3.get_object(Bucket=bucket, Key=tar_key)
            body = resp["Body"].read()

        elapsed = time.perf_counter() - t0
        logger.info(
            "Downloaded %s: %.1f MB in %.1fs (%.0f MB/s, %d workers)",
            tar_key, len(body) / 1e6, elapsed,
            len(body) / elapsed / 1e6, self._dl_workers,
        )
        return body

    def _extract_and_decode(
        self, tar_bytes: bytes, target_sr: int = 16_000,
    ) -> list[tuple[str, Segment]]:
        """Extract FLACs from tar, decode to sorted waveform Segments."""
        t0 = time.perf_counter()
        flac_map: dict[str, bytes] = {}
        with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r") as tar:
            for member in tar.getmembers():
                if member.isfile() and member.name.endswith(".flac"):
                    f = tar.extractfile(member)
                    if f:
                        flac_map[member.name] = f.read()
        extract_s = time.perf_counter() - t0

        t1 = time.perf_counter()
        results = []
        for name in sorted(flac_map.keys()):
            seg_id = Path(name).stem
            try:
                wav, sr = torchaudio.load(io.BytesIO(flac_map[name]))
                if sr != target_sr:
                    wav = torchaudio.functional.resample(wav, sr, target_sr)
                if wav.shape[0] > 1:
                    wav = wav.mean(dim=0, keepdim=True)
                duration_s = wav.shape[-1] / target_sr
                results.append((seg_id, Segment(
                    start_s=0.0, end_s=duration_s, audio=wav,
                )))
            except Exception as e:
                logger.warning("Failed to decode FLAC %s: %s", name, e)
        decode_s = time.perf_counter() - t1

        logger.info(
            "Extracted %d FLACs in %.1fs, decoded %d in %.1fs",
            len(flac_map), extract_s, len(results), decode_s,
        )
        return results

    def _prepare_shard(self, shard_row: dict) -> PreparedShard | None:
        """Download only. Keep decode off the prefetch thread.

        Decoding the next shard concurrently with GPU encode cut throughput
        roughly in half in local isolation tests. We only overlap download.
        """
        shard_key = shard_row["shard_key"]
        try:
            t0 = time.perf_counter()
            tar_bytes = self._download_tar(shard_key)
            download_s = time.perf_counter() - t0

            return PreparedShard(
                shard_row=shard_row, shard_key=shard_key,
                tar_bytes=tar_bytes, download_s=download_s,
            )
        except Exception as e:
            logger.error("Failed to prepare shard %s: %s", shard_key, e)
            return None

    def _decode_prepared(self, prepared: PreparedShard) -> PreparedShard | None:
        """Decode the current shard on the main thread after download overlap."""
        try:
            t0 = time.perf_counter()
            decoded = self._extract_and_decode(prepared.tar_bytes, self.cfg.codec.target_sr)
            prepared.decode_s = time.perf_counter() - t0
            prepared.tar_bytes = b""
            prepared.decoded = decoded
            if not decoded:
                logger.warning("No segments decoded from %s", prepared.shard_key)
                return None
            return prepared
        except Exception as e:
            logger.error("Failed to decode shard %s: %s", prepared.shard_key, e)
            return None

    # ── Encode with checkpointing + self-recovery ────────────────────

    def _encode_prepared(self, prepared: PreparedShard) -> ShardResult | None:
        """Encode a prepared shard. Checkpoints every CKPT_INTERVAL segments.

        On OOM or transient error: saves checkpoint, clears CUDA, resumes
        from the checkpoint — only re-encodes segments AFTER the failure point.
        """
        t0 = time.perf_counter()
        tmp_dir = self.cfg.worker.local_tmp_dir
        ckpt_file = _ckpt_path(tmp_dir, prepared.shard_key)

        # ── Resume from checkpoint if one exists ──
        done_ids: set[str] = set()
        segment_ids: list[str] = []
        all_tokens: list[np.ndarray] = []
        token_counts: list[int] = []
        total_audio_s = 0.0
        total_encode_ms = 0.0

        if ckpt_file.exists():
            done_ids, segment_ids, all_tokens, token_counts = _load_checkpoint(ckpt_file)
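            # Audio seconds are reconstructed from token counts (TOKENS_PER_SEC);
            # per-segment encode time for checkpointed work is not recoverable.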
            for t_arr in all_tokens:
                total_audio_s += len(t_arr) / TOKENS_PER_SEC
            logger.info(
                "Resuming shard %s from checkpoint: %d segments already done",
                prepared.shard_key, len(done_ids),
            )

        assert prepared.decoded is not None, "Call _decode_prepared() first"
        remaining = [(sid, seg) for sid, seg in prepared.decoded if sid not in done_ids]
        segments_failed = 0
        max_seg_duration = 0.0
        since_last_ckpt = 0

        logger.info(
            "Encoding %s: %d remaining (%d already checkpointed)",
            prepared.shard_key, len(remaining), len(done_ids),
        )

        batch_size = 256
        i = 0
        while i < len(remaining):
            batch = remaining[i : i + batch_size]
            seg_ids_batch = [sid for sid, _ in batch]
            segments_batch = [seg for _, seg in batch]

            try:
                encoded = self.encoder.encode_segments(
                    segments_batch,
                    xcodec_batch_size_override=self.cfg.codec.xcodec_batch_size,
                )
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                if "out of memory" not in str(e).lower():
                    raise
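
                # torch.cuda.OutOfMemoryError subclasses RuntimeError on recent
                # torch builds; the message check filters out non-OOM RuntimeErrors.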

                # OOM recovery: save what we have, clear GPU, retry this batch at B=1
                logger.warning(
                    "OOM at segment %d/%d (batch=%d). Saving checkpoint, clearing CUDA...",
                    len(segment_ids), len(remaining) + len(done_ids), len(batch),
                )
                if segment_ids:
                    _save_checkpoint(ckpt_file, segment_ids, all_tokens, token_counts)
                torch.cuda.empty_cache()
                gc.collect()
                time.sleep(2)

                # Retry this batch at B=1
                try:
                    encoded = self.encoder.encode_segments(
                        segments_batch, xcodec_batch_size_override=1,
                    )
                except Exception as inner_e:
                    logger.error("B=1 retry also failed: %s", inner_e)
                    segments_failed += len(batch)
                    torch.cuda.empty_cache()
                    gc.collect()
                    i += batch_size
                    continue

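            # Attribute encoder outputs to inputs positionally. This assumes
            # encode_segments preserves input order, so a short result list means
            # the tail of the batch failed.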
            enc_idx = 0
            for sid, seg in zip(seg_ids_batch, segments_batch):
                if enc_idx < len(encoded):
                    es = encoded[enc_idx]
                    enc_idx += 1
                    tokens_np = es.xcodec2_tokens.squeeze(0).numpy().astype(np.uint16)
                    segment_ids.append(sid)
                    all_tokens.append(tokens_np)
                    token_counts.append(len(tokens_np))
                    total_audio_s += seg.duration_s
                    total_encode_ms += es.encode_time_ms
                    max_seg_duration = max(max_seg_duration, seg.duration_s)
                    since_last_ckpt += 1
                else:
                    segments_failed += 1

            # Periodic checkpoint every CKPT_INTERVAL segments
            if since_last_ckpt >= CKPT_INTERVAL:
                _save_checkpoint(ckpt_file, segment_ids, all_tokens, token_counts)
                since_last_ckpt = 0

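            # Every 4th batch, hand cached CUDA blocks back to the driver to keep
            # reserved memory from creeping up across a long shard.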
            if (i // batch_size) % 4 == 3:
                torch.cuda.empty_cache()

            i += batch_size

        del remaining
        prepared.decoded = []
        gc.collect()

        # Clean up checkpoint on success (final parquet will be uploaded)
        if ckpt_file.exists():
            ckpt_file.unlink()

        elapsed = time.perf_counter() - t0
        vram_mb = torch.cuda.memory_allocated() / 1e6
        vram_peak_mb = torch.cuda.max_memory_allocated() / 1e6

        logger.info(
            "Encoded shard %s: %d/%d segs, %.0fs audio, %.1fs wall, "
            "RTF=%.0fx, max_seg=%.1fs, VRAM=%.0f/%.0f MB (cur/peak)",
            prepared.shard_key, len(segment_ids),
            len(segment_ids) + segments_failed,
            total_audio_s, elapsed, total_audio_s / max(elapsed, 0.001),
            max_seg_duration, vram_mb, vram_peak_mb,
        )

        if not segment_ids:
            return None

        return ShardResult(
            shard_key=prepared.shard_key,
            segment_ids=segment_ids,
            xcodec2_tokens=all_tokens,
            token_counts=token_counts,
            total_audio_s=total_audio_s,
            total_encode_ms=total_encode_ms,
            segments_encoded=len(segment_ids),
            segments_failed=segments_failed,
        )

    # ── Upload ───────────────────────────────────────────────────────

    def _upload_tokens(self, result: ShardResult) -> str:
        """Pack tokens into parquet and upload alongside original shard."""
        token_bytes_list = [t.tobytes() for t in result.xcodec2_tokens]

        table = pa.table({
            "segment_id": pa.array(result.segment_ids, type=pa.string()),
            "xcodec2_tokens": pa.array(token_bytes_list, type=pa.binary()),
            "token_count": pa.array(result.token_counts, type=pa.int32()),
        })

        buf = io.BytesIO()
        pq.write_table(table, buf, compression="zstd")
        parquet_bytes = buf.getvalue()

        bucket = os.environ.get("R2_BUCKET_DESTINATION", "finalsftdata")
        upload_key = f"{result.shard_key}xcodec2_tokens.parquet"

        self._s3.put_object(
            Bucket=bucket, Key=upload_key,
            Body=parquet_bytes,
            ContentType="application/octet-stream",
        )

        logger.info(
            "Uploaded %s: %.1f MB, %d segments",
            upload_key, len(parquet_bytes) / 1e6, len(result.segment_ids),
        )
        return upload_key

    # ── Main loop with prefetch ──────────────────────────────────────

    def run(self, max_shards: int | None = None, dataset: str | None = None,
            language: str | None = None, benchmark: bool = False) -> None:
        """Main loop with prefetch: download shard N+1 while encoding shard N.

        benchmark=True: process exactly one shard, but also prefetch the next
        shard's download so we can confirm overlap is working without paying
        the cost of encoding a second shard.
        """
        self._running = True
        shards_done = 0
        timing_log: list[dict] = []

        if benchmark:
            max_shards = 1
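            # Reset the peak-VRAM counter so the benchmark report reflects only this run.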
            torch.cuda.reset_peak_memory_stats()

        def _shutdown(sig, frame):
            logger.info("Received signal %s, shutting down...", sig)
            self._running = False
        signal.signal(signal.SIGTERM, _shutdown)
        signal.signal(signal.SIGINT, _shutdown)

        logger.info(
            "=== SFT Worker Starting (max=%s, dataset=%s, lang=%s, benchmark=%s) ===",
            max_shards, dataset, language, benchmark,
        )

        prefetch_future: Future | None = None
        prefetch_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="prefetch")

        def _claim_and_prepare() -> PreparedShard | None:
            shard = self.orch.claim_shard(
                self.cfg.worker.worker_id, dataset=dataset, language=language,
            )
            if shard is None:
                return None
            self.orch.update_shard_status(shard["shard_key"], "PROCESSING", {
                "claimed_by": self.cfg.worker.worker_id,
            })
            result = self._prepare_shard(shard)
            if result is None:
                # Download failed — release shard back to PENDING so another worker can try
                try:
                    self.orch.update_shard_status(shard["shard_key"], "PENDING", {
                        "claimed_by": None, "claimed_at": None, "started_at": None,
                    })
                except Exception:
                    pass
            return result

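        # Prime the pipeline: claim and download the first shard before the loop starts.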
        prefetch_future = prefetch_pool.submit(_claim_and_prepare)
        _no_work_retries = 0
        _MAX_NO_WORK_RETRIES = 10

        while self._running:
            if max_shards and shards_done >= max_shards:
                logger.info("Reached max_shards=%d, stopping", max_shards)
                break

            if prefetch_future is None:
                prefetch_future = prefetch_pool.submit(_claim_and_prepare)
            prepared = prefetch_future.result()
            prefetch_future = None

            if prepared is None:
                _no_work_retries += 1
                if _no_work_retries >= _MAX_NO_WORK_RETRIES:
                    logger.info("No shards after %d retries, stopping", _no_work_retries)
                    break
                logger.warning(
                    "Prepare returned None (attempt %d/%d), retrying in 30s...",
                    _no_work_retries, _MAX_NO_WORK_RETRIES,
                )
                time.sleep(30)
                continue
            _no_work_retries = 0

            needs_more = benchmark or (max_shards is None) or (shards_done + 1 < max_shards)
            if needs_more and self._running:
                prefetch_future = prefetch_pool.submit(_claim_and_prepare)
                logger.info("Prefetch: next shard download queued while processing %s", prepared.shard_key)

            try:
                prepared = self._decode_prepared(prepared)
                if prepared is None:
                    self._consecutive_errors += 1
                    continue

                if prefetch_future is not None and prefetch_future.done():
                    logger.info("Prefetch download completed before encode start for %s", prepared.shard_key)

                shard_t0 = time.perf_counter()
                result = self._encode_prepared(prepared)
                encode_s = time.perf_counter() - shard_t0

                if result is None:
                    self.orch.update_shard_status(prepared.shard_key, "FAILED", {
                        "error_detail": "Encode returned no results",
                    })
                    self._consecutive_errors += 1
                    continue

                upload_t0 = time.perf_counter()
                r2_key = self._upload_tokens(result)
                upload_s = time.perf_counter() - upload_t0

                self.orch.update_shard_status(prepared.shard_key, "DONE", {
                    "segments_encoded": result.segments_encoded,
                    "segments_failed": result.segments_failed,
                    "total_audio_s": round(result.total_audio_s, 1),
                    "total_encode_ms": round(result.total_encode_ms, 1),
                    "output_r2_key": r2_key,
                })

                self._total_shards += 1
                self._total_audio_s += result.total_audio_s
                self._total_segments += result.segments_encoded
                self._consecutive_errors = 0
                shards_done += 1

                shard_timing = {
                    "shard": prepared.shard_key,
                    "download_s": prepared.download_s,
                    "decode_s": prepared.decode_s,
                    "encode_s": encode_s,
                    "upload_s": upload_s,
                    "audio_s": result.total_audio_s,
                    "segments": result.segments_encoded,
                    "rtf": result.total_audio_s / max(encode_s, 0.001),
                }
                timing_log.append(shard_timing)

                logger.info(
                    "Shard %d done: dl=%.1fs dec=%.1fs enc=%.1fs up=%.1fs | "
                    "%.0fs audio, %d segs, RTF=%.0fx",
                    shards_done, prepared.download_s, prepared.decode_s,
                    encode_s, upload_s, result.total_audio_s,
                    result.segments_encoded, shard_timing["rtf"],
                )

                if benchmark:
                    prefetched_ready = False
                    prefetched = None
                    if prefetch_future is not None:
                        if not prefetch_future.done():
                            logger.info("Waiting for prefetched shard download to finish for benchmark...")
                        prefetched = prefetch_future.result()
                        prefetched_ready = prefetched is not None
                    shard_timing["next_download_ready"] = prefetched_ready
                    if prefetched_ready and prefetched is not None:
                        logger.info("Prefetch confirmed ready: %s", prefetched.shard_key)
                        # Benchmark mode stops after shard 1. Release the prefetched claim.
                        self.orch.update_shard_status(prefetched.shard_key, "PENDING", {
                            "claimed_by": None,
                            "claimed_at": None,
                            "started_at": None,
                        })
                    break

            except Exception as e:
                error_msg = f"{type(e).__name__}: {e}"
                logger.error("Shard %s failed: %s\n%s", prepared.shard_key, error_msg, traceback.format_exc())
                try:
                    self.orch.update_shard_status(prepared.shard_key, "FAILED", {
                        "error_detail": error_msg[:2000],
                    })
                except Exception:
                    pass
                self._consecutive_errors += 1
                if "out of memory" in str(e).lower():
                    torch.cuda.empty_cache()
                    gc.collect()

            if self._consecutive_errors >= self.cfg.worker.max_retries:
                logger.error("Too many errors (%d), pausing 60s", self._consecutive_errors)
                time.sleep(60)
                self._consecutive_errors = 0

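        # Best-effort: cancel() only stops a future that has not started. An
        # in-flight prefetch finishes in the background, and any shard it claimed
        # stays PROCESSING in Supabase.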
        if prefetch_future and not prefetch_future.done():
            prefetch_future.cancel()
        prefetch_pool.shutdown(wait=False)

        if benchmark and timing_log:
            self._print_benchmark_report(timing_log)

        logger.info(
            "=== SFT Worker Done. %d shards, %d segments, %.0fs audio ===",
            shards_done, self._total_segments, self._total_audio_s,
        )

    def _print_benchmark_report(self, timing_log: list[dict]) -> None:
        """Detailed benchmark report for fleet scaling estimates."""
        vram_peak = torch.cuda.max_memory_allocated() / 1e6
        gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
        vram_total = torch.cuda.get_device_properties(0).total_memory / 1e6 if torch.cuda.is_available() else 0

        rep = timing_log[-1]
        next_download_ready = bool(rep.get("next_download_ready", False))
        if next_download_ready:
            effective = rep["decode_s"] + rep["encode_s"] + rep["upload_s"]
        else:
            effective = rep["download_s"] + rep["decode_s"] + rep["encode_s"] + rep["upload_s"]

        orch = SFTOrchestrator()
        stats = orch.get_stats()
        total_pending = sum(r["cnt"] for r in stats if r["status"] == "PENDING")

        print(f"\n{'='*70}")
        print(f"  SFT BENCHMARK REPORT")
        print(f"{'='*70}")
        print(f"  GPU: {gpu_name}")
        print(f"  VRAM: {vram_peak:.0f} / {vram_total:.0f} MB (peak / total)")
        print(f"  Batch size: {self.cfg.codec.xcodec_batch_size}")
        print(f"  Download workers: {self._dl_workers} (of {os.cpu_count()} vCPUs)")
        print(f"  Checkpoint interval: {CKPT_INTERVAL} segments")
        print()

        for i, t in enumerate(timing_log):
            label = "(benchmark shard)" if i == 0 else "(extra)"
            print(f"  Shard {i+1} {label}:")
            print(f"    Download:  {t['download_s']:>7.1f}s")
            print(f"    Decode:    {t['decode_s']:>7.1f}s")
            print(f"    Encode:    {t['encode_s']:>7.1f}s (RTF={t['rtf']:.0f}x)")
            print(f"    Upload:    {t['upload_s']:>7.1f}s")
            print(f"    Audio:     {t['audio_s']:>7.0f}s ({t['audio_s']/3600:.1f}h)")
            print(f"    Segments:  {t['segments']}")
            print(f"    Next download ready by upload finish: {bool(t.get('next_download_ready', False))}")
            print()

        print(f"  Effective per shard: {effective:.1f}s")
        if next_download_ready:
            print("  Download is hidden by overlap; decode+encode+upload dominate.")
        else:
            print("  Download was not fully hidden; full pipeline time is used.")
        print()

        eta_1gpu_h = (total_pending * effective) / 3600
        for n in [1, 10, 50, 100, 200]:
            eta_h = eta_1gpu_h / n
            print(f"  ETA {n:>3} GPUs: {eta_h:>7.1f}h ({eta_h/24:.1f} days) for {total_pending} shards")

        print(f"{'='*70}\n")
