#!/usr/bin/env python3
"""
Neucodec encoding worker.

Claims shards from Supabase, downloads audio.tar + metadata.parquet from R2,
encodes all audio with neucodec, uploads neucodec_tokens.parquet back to R2.

Reports per-worker state to neucodec_workers table for monitoring.

Usage:
    python worker.py                           # auto-generate worker ID
    python worker.py --worker-id my-worker-01  # explicit worker ID
    python worker.py --max-shards 5            # process at most 5 shards then exit
    python worker.py --dry-run                 # claim but don't process
"""

import argparse
import io
import logging
import os
import queue
import socket
import tarfile
import tempfile
import threading
import time
import uuid
from pathlib import Path

import boto3
import numpy as np
import pandas as pd
import psycopg2
import soundfile as sf
import torch
import torchaudio.functional as AF

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger("neucodec-worker")

# ── Config from environment ──────────────────────────────────────────────────
# Required settings: a missing variable raises KeyError at import time.
DATABASE_URL, R2_ENDPOINT_URL, R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY = (
    os.environ[name]
    for name in ("DATABASE_URL", "R2_ENDPOINT_URL", "R2_ACCESS_KEY_ID", "R2_SECRET_ACCESS_KEY")
)
# Optional: bucket used for shard inputs/outputs and the model-weights archive.
R2_BUCKET = os.environ.get("R2_BUCKET_SFT_DATA", "finalsftdata")

HEARTBEAT_INTERVAL = 30  # seconds between heartbeat writes
STALE_CLAIM_TIMEOUT = 10 * 60  # seconds (10 min) before reclaim_stale() resets a claim
MAX_AUDIO_SAMPLES = 16000 * 30  # 30 s at 16 kHz; longer clips are encoded in chunks to avoid OOM


def get_db(retries=10, base_delay=5):
    """Open an autocommit psycopg2 connection to DATABASE_URL, retrying on failure.

    Only ``psycopg2.OperationalError`` is retried. Between attempts the worker
    sleeps ``base_delay * 2**min(attempt, 4)`` seconds plus a jitter in [0, 5)
    derived from the wall clock. The error from the final attempt is re-raised.
    With ``retries <= 0`` no attempt is made and None is returned.
    """
    final_attempt = retries - 1
    for attempt in range(retries):
        try:
            connection = psycopg2.connect(DATABASE_URL)
            connection.autocommit = True
            return connection
        except psycopg2.OperationalError as err:
            if attempt == final_attempt:
                raise
            backoff = base_delay * (2 ** min(attempt, 4))
            jitter = time.time() % 5
            wait_s = backoff + jitter
            log.warning("DB connection failed (attempt %d/%d), retrying in %.0fs: %s",
                        attempt + 1, retries, wait_s, str(err)[:80])
            time.sleep(wait_s)


def get_s3():
    """Build a boto3 S3 client for the S3-compatible R2 endpoint from the environment."""
    client_kwargs = dict(
        endpoint_url=R2_ENDPOINT_URL,
        aws_access_key_id=R2_ACCESS_KEY_ID,
        aws_secret_access_key=R2_SECRET_ACCESS_KEY,
        region_name="auto",
    )
    return boto3.client("s3", **client_kwargs)


# ── Worker state (shared mutable state for heartbeat thread) ─────────────────

class WorkerState:
    """Thread-safe worker state for heartbeat reporting.

    The main thread updates it through the methods below, and the heartbeat
    thread reads a consistent copy via snapshot(). Every mutable field is
    guarded by ``self.lock``.
    """

    def __init__(self, worker_id: str, gpu_name: str, gpu_vram_mb: int):
        # Static identity, reported on every heartbeat.
        self.worker_id = worker_id
        self.hostname = socket.gethostname()
        self.gpu_name = gpu_name
        self.gpu_vram_mb = gpu_vram_mb
        self.lock = threading.Lock()
        # Mutable fields updated by main thread, read by heartbeat thread
        self.state = "idle"
        self.current_shard_id = None
        self.current_shard_prefix = None
        self.shards_completed = 0
        self.shards_failed = 0
        self.total_audio_hours = 0.0
        self.total_tokens = 0
        self.progress_pct = 0.0
        self.progress_segments = 0
        self.progress_total_segments = 0
        self.progress_rtf = 0.0
        self.progress_vram_gb = 0.0
        self.progress_phase = "idle"
        self.last_error = None
        # time.time() of the most recent failure; kept locally, not part of snapshot().
        self.last_error_at = None

    def _reset_current_shard(self):
        """Return to idle and clear all per-shard progress. Caller must hold ``self.lock``."""
        self.state = "idle"
        self.current_shard_id = None
        self.current_shard_prefix = None
        self.progress_pct = 0.0
        self.progress_segments = 0
        self.progress_total_segments = 0
        self.progress_rtf = 0.0
        self.progress_phase = "idle"

    def set_phase(self, phase: str, shard_id=None, shard_prefix=None):
        """Mark the worker busy in ``phase``; optionally record the shard it is working on."""
        with self.lock:
            self.state = "working"
            self.progress_phase = phase
            if shard_id is not None:
                self.current_shard_id = shard_id
            if shard_prefix is not None:
                self.current_shard_prefix = shard_prefix

    def set_progress(self, processed: int, total: int, rtf: float):
        """Record encoding progress (encode_shard passes audio-seconds per wall-second as ``rtf``)."""
        # Query the allocator outside the lock; report 0 on hosts without CUDA.
        vram_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
        with self.lock:
            self.progress_segments = processed
            self.progress_total_segments = total
            self.progress_pct = processed / total * 100 if total > 0 else 0
            self.progress_rtf = rtf
            self.progress_vram_gb = vram_gb

    def mark_shard_done(self, audio_hours: float, tokens: int):
        """Add a completed shard to the lifetime totals and go idle."""
        with self.lock:
            self.shards_completed += 1
            self.total_audio_hours += audio_hours
            self.total_tokens += tokens
            self._reset_current_shard()

    def mark_shard_failed(self, error: str):
        """Count a failed shard, remember its error (truncated to 500 chars) and go idle."""
        with self.lock:
            self.shards_failed += 1
            self.last_error = error[:500]
            self.last_error_at = time.time()
            self._reset_current_shard()

    def snapshot(self):
        """Return a dict copy of the reportable fields, keyed like the neucodec_workers columns."""
        with self.lock:
            return {
                "worker_id": self.worker_id,
                "hostname": self.hostname,
                "gpu_name": self.gpu_name,
                "gpu_vram_mb": self.gpu_vram_mb,
                "state": self.state,
                "current_shard_id": self.current_shard_id,
                "current_shard_prefix": self.current_shard_prefix,
                "shards_completed": self.shards_completed,
                "shards_failed": self.shards_failed,
                "total_audio_hours": self.total_audio_hours,
                "total_tokens": self.total_tokens,
                "progress_pct": self.progress_pct,
                "progress_segments": self.progress_segments,
                "progress_total_segments": self.progress_total_segments,
                "progress_rtf": self.progress_rtf,
                "progress_vram_gb": self.progress_vram_gb,
                "progress_phase": self.progress_phase,
                "last_error": self.last_error,
            }


# ── Heartbeat ────────────────────────────────────────────────────────────────

def heartbeat_loop(wstate: WorkerState, stop_event: threading.Event):
    """Periodically write worker state to both neucodec_workers and neucodec_shards.

    Every HEARTBEAT_INTERVAL seconds, until ``stop_event`` is set:
      * upsert this worker's snapshot into neucodec_workers;
      * refresh ``claimed_at`` on the current shard and on every shard still
        claimed/processing under this worker_id. The second set includes a
        shard prefetched by main() while another one encodes. Without the
        refresh, reclaim_stale() would treat it as abandoned after
        STALE_CLAIM_TIMEOUT and hand it to another worker.

    Failures are logged and retried on the next tick; this thread never raises.
    """
    while not stop_event.wait(HEARTBEAT_INTERVAL):
        conn = None
        try:
            snap = wstate.snapshot()
            conn = get_db()
            with conn.cursor() as cur:
                # Upsert worker state
                cur.execute("""
                    INSERT INTO neucodec_workers (
                        worker_id, hostname, gpu_name, gpu_vram_mb, state,
                        current_shard_id, current_shard_prefix,
                        shards_completed, shards_failed,
                        total_audio_hours, total_tokens,
                        progress_pct, progress_segments, progress_total_segments,
                        progress_rtf, progress_vram_gb, progress_phase,
                        last_heartbeat, last_error
                    ) VALUES (
                        %(worker_id)s, %(hostname)s, %(gpu_name)s, %(gpu_vram_mb)s, %(state)s,
                        %(current_shard_id)s, %(current_shard_prefix)s,
                        %(shards_completed)s, %(shards_failed)s,
                        %(total_audio_hours)s, %(total_tokens)s,
                        %(progress_pct)s, %(progress_segments)s, %(progress_total_segments)s,
                        %(progress_rtf)s, %(progress_vram_gb)s, %(progress_phase)s,
                        NOW(), %(last_error)s
                    ) ON CONFLICT (worker_id) DO UPDATE SET
                        state = EXCLUDED.state,
                        current_shard_id = EXCLUDED.current_shard_id,
                        current_shard_prefix = EXCLUDED.current_shard_prefix,
                        shards_completed = EXCLUDED.shards_completed,
                        shards_failed = EXCLUDED.shards_failed,
                        total_audio_hours = EXCLUDED.total_audio_hours,
                        total_tokens = EXCLUDED.total_tokens,
                        progress_pct = EXCLUDED.progress_pct,
                        progress_segments = EXCLUDED.progress_segments,
                        progress_total_segments = EXCLUDED.progress_total_segments,
                        progress_rtf = EXCLUDED.progress_rtf,
                        progress_vram_gb = EXCLUDED.progress_vram_gb,
                        progress_phase = EXCLUDED.progress_phase,
                        last_heartbeat = NOW(),
                        last_error = EXCLUDED.last_error
                """, snap)

                # Keep our claims fresh for stale detection. `id = NULL` matches
                # nothing, so an idle worker only refreshes shards it still holds.
                cur.execute("""
                    UPDATE neucodec_shards SET claimed_at = NOW()
                    WHERE id = %(current_shard_id)s
                       OR (worker_id = %(worker_id)s AND status IN ('claimed', 'processing'))
                """, snap)
        except Exception as e:
            # A silent heartbeat lets our shards go stale; make failures visible.
            log.warning("Heartbeat failed: %s", e)
        finally:
            if conn is not None:
                conn.close()


# ── Shard claim & status ────────────────────────────────────────────────────

def claim_shard(conn, worker_id: str):
    """Claim the next shard via the ``claim_neucodec_shard`` SQL function.

    Returns:
        dict with keys id / prefix / dataset / language (the first four result
        columns), or None if the function returned no row.
    """
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM claim_neucodec_shard(%s)", (worker_id,))
    claimed = cursor.fetchone()
    cursor.close()
    if not claimed:
        return None
    return {
        "id": claimed[0],
        "prefix": claimed[1],
        "dataset": claimed[2],
        "language": claimed[3],
    }


def update_shard_status(conn, shard_id: int, status: str, **kwargs):
    """Set ``status`` and any extra columns on one row of neucodec_shards.

    Args:
        conn: open DB connection (callers in this file pass autocommit
            connections from get_db()).
        shard_id: id of the shard row to update.
        status: new value for the ``status`` column.
        **kwargs: extra ``column=value`` pairs. The exact string ``"NOW()"``
            is emitted as the SQL ``NOW()`` function instead of being bound as
            a text literal, and ``None`` becomes NULL.

    Raises:
        ValueError: if a column name is not a plain identifier. Column names
            are interpolated into the SQL text, while values are always bound
            as parameters.
    """
    assignments = ["status = %s"]
    params = [status]
    for column, value in kwargs.items():
        if not column.isidentifier():
            raise ValueError(f"invalid column name: {column!r}")
        if isinstance(value, str) and value == "NOW()":
            # Binding 'NOW()' as text only works through PostgreSQL's lenient
            # datetime parsing; call the function explicitly instead.
            assignments.append(f"{column} = NOW()")
        else:
            assignments.append(f"{column} = %s")
            params.append(value)
    params.append(shard_id)
    cur = conn.cursor()
    try:
        cur.execute(f"UPDATE neucodec_shards SET {', '.join(assignments)} WHERE id = %s", params)
    finally:
        cur.close()


def reclaim_stale(conn, stale_after_s=None):
    """Return shards with stale claims to the pending pool.

    Shards in 'claimed' or 'processing' whose ``claimed_at`` is older than
    ``stale_after_s`` seconds get status 'pending', with ``worker_id`` and
    ``claimed_at`` cleared.

    Args:
        conn: open DB connection.
        stale_after_s: age threshold in seconds; defaults to STALE_CLAIM_TIMEOUT.

    Returns:
        list of ``(id, shard_prefix)`` rows that were reset (empty if none).
    """
    if stale_after_s is None:
        stale_after_s = STALE_CLAIM_TIMEOUT
    cur = conn.cursor()
    try:
        # Multiply a bound number by a unit interval rather than splicing the
        # placeholder into a quoted INTERVAL literal. The latter only works
        # while the parameter happens to render unquoted.
        cur.execute("""
            UPDATE neucodec_shards
            SET status = 'pending', worker_id = NULL, claimed_at = NULL
            WHERE status IN ('claimed', 'processing')
              AND claimed_at < NOW() - %s * INTERVAL '1 second'
            RETURNING id, shard_prefix
        """, (stale_after_s,))
        reclaimed = cur.fetchall()
    finally:
        cur.close()
    if reclaimed:
        log.warning("Reclaimed %d stale shards: %s", len(reclaimed), [r[1] for r in reclaimed])
    return reclaimed


# ── Download ─────────────────────────────────────────────────────────────────

def download_shard(s3, shard_prefix: str, work_dir: Path):
    """Download ``<shard_prefix>/audio.tar`` and ``<shard_prefix>/metadata.parquet`` into work_dir.

    Args:
        s3: boto3 S3 client (see get_s3()).
        shard_prefix: object-key prefix of the shard in R2_BUCKET.
        work_dir: existing local directory to write both files into.

    Returns:
        (tar_path, meta_path, tar_size_bytes)
    """
    from boto3.s3.transfer import TransferConfig

    tar_path = work_dir / "audio.tar"
    meta_path = work_dir / "metadata.parquet"

    log.info("Downloading %s ...", shard_prefix)
    # Monotonic clock: wall-clock time can step and yield a zero/negative duration.
    t0 = time.monotonic()

    # The tar uses multipart transfer: 64 MB parts, 8 concurrent.
    config = TransferConfig(
        multipart_threshold=64 * 1024 * 1024,
        max_concurrency=8,
        multipart_chunksize=64 * 1024 * 1024,
    )
    s3.download_file(R2_BUCKET, f"{shard_prefix}/audio.tar", str(tar_path), Config=config)
    s3.download_file(R2_BUCKET, f"{shard_prefix}/metadata.parquet", str(meta_path))

    tar_size = tar_path.stat().st_size
    dl_time = time.monotonic() - t0
    # Never let the throughput figure (ZeroDivisionError) fail a good download.
    rate_mb_s = tar_size / dl_time / 1e6 if dl_time > 0 else float("inf")
    log.info("Downloaded %.1f GB in %.0fs (%.1f MB/s)", tar_size / 1e9, dl_time, rate_mb_s)
    return tar_path, meta_path, tar_size


# ── Encode ───────────────────────────────────────────────────────────────────

MODEL_CACHE_TAR = "neucodec_models.tar.gz"
# Resolve the hub cache the same way huggingface_hub does:
# HF_HUB_CACHE > HUGGINGFACE_HUB_CACHE > $HF_HOME/hub, where HF_HOME defaults to
# $XDG_CACHE_HOME/huggingface (or ~/.cache/huggingface). This lets the weights
# extracted by ensure_model_weights() land where from_pretrained() looks for them.
# With none of these variables set, this is ~/.cache/huggingface/hub as before.
_HF_HOME = os.path.expanduser(
    os.environ.get(
        "HF_HOME",
        os.path.join(os.environ.get("XDG_CACHE_HOME", os.path.join("~", ".cache")), "huggingface"),
    )
)
HF_CACHE_DIR = os.path.expanduser(
    os.environ.get(
        "HF_HUB_CACHE",
        os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.join(_HF_HOME, "hub")),
    )
)


def ensure_model_weights():
    """Download model weights from R2 if not already cached locally.

    The cache counts as populated when
    ``HF_CACHE_DIR/models--neuphonic--neucodec/refs`` is a directory. Otherwise
    MODEL_CACHE_TAR is fetched from R2_BUCKET and untarred into HF_CACHE_DIR.

    Raises:
        Whatever the S3 download raises, or subprocess.CalledProcessError if
        ``tar`` fails. The temporary archive is removed in every case.
    """
    import subprocess
    from boto3.s3.transfer import TransferConfig

    marker = os.path.join(HF_CACHE_DIR, "models--neuphonic--neucodec", "refs")
    if os.path.isdir(marker):
        log.info("Model weights already cached")
        return

    log.info("Downloading model weights from R2...")
    t0 = time.time()
    s3 = get_s3()
    # Unique temp file: a fixed /tmp path would let two workers on one host
    # overwrite each other's archive while one of them is extracting it.
    fd, tar_path = tempfile.mkstemp(prefix="neucodec_models_", suffix=".tar.gz")
    os.close(fd)
    try:
        config = TransferConfig(
            multipart_threshold=64 * 1024 * 1024,
            max_concurrency=8,
            multipart_chunksize=64 * 1024 * 1024,
        )
        s3.download_file(R2_BUCKET, MODEL_CACHE_TAR, tar_path, Config=config)

        os.makedirs(HF_CACHE_DIR, exist_ok=True)
        subprocess.run(["tar", "xzf", tar_path, "-C", HF_CACHE_DIR], check=True)
    finally:
        # Do not leave a partial/complete archive behind if download or extract fails.
        try:
            os.remove(tar_path)
        except OSError:
            pass
    log.info("Model weights downloaded and extracted in %.0fs", time.time() - t0)


def load_codec():
    """Load NeuCodec onto CUDA with its semantic encoder cut down to 17 layers.

    Makes sure the weights are in the local cache first (ensure_model_weights),
    then loads ``neuphonic/neucodec``, keeps only the first 17 layers of
    ``semantic_model.encoder`` and puts the model in eval mode on "cuda".
    """
    from neucodec import NeuCodec
    import torch.nn as nn

    ensure_model_weights()

    log.info("Loading neucodec model...")
    started = time.time()
    model = NeuCodec.from_pretrained("neuphonic/neucodec")

    # encode_shard() reads hidden_states[16]; later layers are presumably
    # dropped to save compute. TODO confirm the 17-layer cut against the
    # model's hidden_states indexing.
    kept_layers = 17
    encoder = model.semantic_model.encoder
    layer_list = list(encoder.layers)
    encoder.layers = nn.ModuleList(layer_list[:kept_layers])
    log.info("Truncated semantic model: %d → %d layers", len(layer_list), kept_layers)

    model.eval().to("cuda")
    log.info("Neucodec loaded in %.1fs, VRAM: %.1f GB", time.time() - started, torch.cuda.memory_allocated() / 1e9)
    return model


def _init_gpu_fbank(codec):
    """Pre-compute GPU tensors for mel spectrogram extraction."""
    fe = codec.feature_extractor
    mel_filters = torch.from_numpy(fe.mel_filters.T).double().to("cuda")  # (80, 257)
    window = torch.from_numpy(fe.window).double().to("cuda")  # (400,)
    return mel_filters, window


def _gpu_fbank(audio_gpu, mel_filters, window):
    """GPU-based fbank: replaces CPU numpy spectrogram. 20x faster, bit-equivalent."""
    x = audio_gpu.double().squeeze() * (2**15)  # Kaldi compliance
    frame_length, hop_length, fft_length = 400, 160, 512
    num_frames = int(1 + (x.shape[0] - frame_length) // hop_length)
    if num_frames <= 0:
        return None
    idx = torch.arange(num_frames, device="cuda").unsqueeze(1) * hop_length + \
          torch.arange(frame_length, device="cuda").unsqueeze(0)
    frames = x[idx]
    frames = frames - frames.mean(dim=1, keepdim=True)  # remove DC offset
    preemph = frames.clone()
    preemph[:, 1:] = frames[:, 1:] - 0.97 * frames[:, :-1]
    preemph[:, 0] = frames[:, 0] * 0.03
    preemph *= window
    padded = torch.zeros(num_frames, fft_length, device="cuda", dtype=torch.float64)
    padded[:, :frame_length] = preemph
    power_spec = torch.fft.rfft(padded).abs() ** 2
    mel_spec = torch.clamp(mel_filters @ power_spec.T, min=1.192092955078125e-07)
    features = torch.log(mel_spec).T.float()  # (num_frames, 80)
    mean = features.mean(dim=0, keepdim=True)
    var = features.var(dim=0, unbiased=True, keepdim=True)
    features = (features - mean) / torch.sqrt(var + 1e-7)
    nf = features.shape[0]
    rem = nf % 2
    if rem:
        features = features[:nf - rem]
    return features.reshape(-1, 160).unsqueeze(0)  # (1, T, 160)


def _load_member_audio(tf, member, target_sr):
    """Decode one FLAC member of an open tarfile to mono audio at ``target_sr``.

    Returns:
        (audio, sr): 1-D numpy waveform and its sample rate (== target_sr).

    Raises:
        Whatever extraction, decoding or resampling raises. The caller counts
        any exception as a failed segment.
    """
    f = tf.extractfile(member)
    audio_data, sr = sf.read(io.BytesIO(f.read()))
    # soundfile returns (frames, channels) for multi-channel files. Downmix so
    # that resampling (which acts on the last axis) and chunking see 1-D audio.
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)
    # Resample when needed (e.g. final-export is 48kHz)
    if sr != target_sr:
        audio_t = torch.from_numpy(audio_data).float()
        audio_data = AF.resample(audio_t, sr, target_sr).numpy()
        sr = target_sr
    return audio_data, sr


def _encode_waveform(codec, audio_data, mel_filters, window):
    """Encode one mono waveform (16 kHz from the loader) into neucodec token ids.

    Audio longer than MAX_AUDIO_SAMPLES is encoded in independent chunks, and
    the chunks' codes are concatenated.

    Returns:
        1-D ``np.uint16`` token array, or None if no chunk produced codes.
    """
    # _gpu_fbank stacks frame pairs (400-sample window, 160-sample hop), so a
    # chunk needs at least 400 + 160 samples to yield any semantic features.
    # Shorter chunks, typically the tail of a long clip, are skipped.
    min_chunk_samples = 400 + 160
    all_codes = []
    for start in range(0, len(audio_data), MAX_AUDIO_SAMPLES):
        chunk = audio_data[start:start + MAX_AUDIO_SAMPLES]
        if len(chunk) < min_chunk_samples:
            continue
        y_gpu = torch.tensor(chunk, dtype=torch.float32, device="cuda")[None, None]
        feats_gpu = _gpu_fbank(y_gpu.squeeze(), mel_filters, window)
        if feats_gpu is None:
            continue

        with torch.no_grad():
            acoustic_emb = codec.CodecEnc(y_gpu).transpose(1, 2)
            semantic_output = (
                codec.semantic_model(feats_gpu).hidden_states[16].transpose(1, 2)
            )
            semantic_encoded = codec.SemanticEncoder_module(semantic_output)

            # If the two streams' frame counts differ, trim both to the shorter.
            if acoustic_emb.shape[-1] != semantic_encoded.shape[-1]:
                ml = min(acoustic_emb.shape[-1], semantic_encoded.shape[-1])
                acoustic_emb = acoustic_emb[:, :, :ml]
                semantic_encoded = semantic_encoded[:, :, :ml]

            concat_emb = torch.cat([semantic_encoded, acoustic_emb], dim=1)
            concat_emb = codec.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
            _, codes, _ = codec.generator(concat_emb, vq=True)

        all_codes.append(codes[0, 0].cpu().numpy().astype(np.uint16))
        # Drop GPU references before the next chunk to keep peak memory down.
        del feats_gpu, y_gpu, acoustic_emb, semantic_output, semantic_encoded
        del concat_emb, codes

    if not all_codes:
        return None
    return np.concatenate(all_codes) if len(all_codes) > 1 else all_codes[0]


def encode_shard(codec, tar_path: Path, meta_path: Path, wstate: WorkerState):
    """Encode every ``*.flac`` member of ``tar_path`` into neucodec tokens.

    A background thread decodes audio into a bounded queue (16 items) while
    this thread runs the encoder. Per-segment failures are logged and counted,
    not raised. A segment that yields no tokens (too short) also counts as failed.

    Args:
        codec: model returned by load_codec().
        tar_path: local audio.tar of the shard.
        meta_path: local metadata.parquet; it is only read as a validity check.
        wstate: receives periodic progress updates.

    Returns:
        (df, total_audio_s, total_tokens, wall_s). ``df`` has one row per
        encoded segment, with columns segment_id, neucodec_tokens (raw uint16
        bytes) and token_count. It is empty when nothing was encoded.
    """
    # Read only so that an unreadable metadata file raises here; the contents
    # are not otherwise used.
    pd.read_parquet(meta_path)
    mel_filters, window = _init_gpu_fbank(codec)
    target_sr = 16000

    results = []
    total_audio_s = 0.0
    total_tokens = 0
    processed = 0
    failed = 0

    tf = tarfile.open(str(tar_path), "r")
    stop_loading = threading.Event()
    loader_thread = None
    try:
        members = [m for m in tf.getmembers() if m.name.endswith(".flac")]
        total_members = len(members)
        log.info("Encoding %d audio files...", total_members)

        member_to_seg = {m.name: m.name.replace(".flac", "").split("/")[-1] for m in members}

        # Audio loading thread — reads FLAC from tar, pushes decoded audio to queue
        audio_queue = queue.Queue(maxsize=16)

        def enqueue(item):
            """Block while the queue is full, but give up once the consumer has stopped."""
            while not stop_loading.is_set():
                try:
                    audio_queue.put(item, timeout=1.0)
                    return True
                except queue.Full:
                    pass
            return False

        def audio_loader():
            try:
                for member in members:
                    seg_id = member_to_seg[member.name]
                    try:
                        audio_data, sr = _load_member_audio(tf, member, target_sr)
                    except Exception as e:
                        log.warning("Failed to read/extract %s: %s", seg_id, e)
                        audio_data, sr = None, 0
                    if not enqueue((seg_id, audio_data, sr)):
                        return
            finally:
                enqueue(None)  # end-of-stream sentinel (no-op once stopped)

        loader_thread = threading.Thread(target=audio_loader, daemon=True)
        loader_thread.start()
        t_start = time.time()

        while True:
            item = audio_queue.get()
            if item is None:
                break

            seg_id, audio_data, sr = item
            processed += 1

            token_array = None
            if audio_data is not None:
                try:
                    token_array = _encode_waveform(codec, audio_data, mel_filters, window)
                except Exception as e:
                    log.warning("Encode failed for %s: %s", seg_id, e)
                else:
                    if token_array is None:
                        log.debug("No tokens produced for %s (audio too short)", seg_id)

            if token_array is None:
                failed += 1
            else:
                results.append({
                    "segment_id": seg_id,
                    "neucodec_tokens": token_array.tobytes(),
                    "token_count": len(token_array),
                })
                total_audio_s += len(audio_data) / sr
                total_tokens += len(token_array)

            if processed % 500 == 0:
                torch.cuda.empty_cache()
                elapsed = time.time() - t_start
                rtf = total_audio_s / elapsed if elapsed > 0 else 0
                wstate.set_progress(processed, total_members, rtf)
            if processed % 3000 == 0:
                elapsed = time.time() - t_start
                rtf = total_audio_s / elapsed if elapsed > 0 else 0
                vram_gb = torch.cuda.memory_allocated() / 1e9
                log.info(
                    "[%d/%d] %.0f%% done, %.0fs elapsed, RTF=%.1fx, VRAM=%.1fGB",
                    processed, total_members, processed / total_members * 100, elapsed, rtf, vram_gb,
                )
    finally:
        # Always release the loader (it may be blocked on a full queue) and the
        # tar handle, even if encoding aborts. Otherwise the thread and the open
        # tar outlive this call after the caller deletes the work dir.
        stop_loading.set()
        if loader_thread is not None:
            loader_thread.join()
        tf.close()

    wall = time.time() - t_start
    rtf = total_audio_s / wall if wall > 0 else 0
    wstate.set_progress(processed, total_members, rtf)
    log.info(
        "Encoding done: %d ok, %d failed, %.1f hrs audio, %.0fs wall, RTF=%.1fx, %d tokens",
        processed - failed, failed, total_audio_s / 3600, wall, rtf, total_tokens,
    )

    df = pd.DataFrame(results)
    return df, total_audio_s, total_tokens, wall


# ── Upload ───────────────────────────────────────────────────────────────────

def upload_results(s3, shard_prefix: str, tokens_df: pd.DataFrame, work_dir: Path):
    """Write tokens_df as parquet into work_dir and upload it under the shard prefix.

    Returns:
        The object key written: ``<shard_prefix>/neucodec_tokens.parquet``.
    """
    local_file = work_dir / "neucodec_tokens.parquet"
    object_key = f"{shard_prefix}/neucodec_tokens.parquet"
    tokens_df.to_parquet(str(local_file), index=False)
    size_mb = local_file.stat().st_size / 1e6
    log.info("Uploading %s (%.1f MB)...", object_key, size_mb)
    s3.upload_file(str(local_file), R2_BUCKET, object_key)
    log.info("Upload complete: %s", object_key)
    return object_key


# ── Main loop ────────────────────────────────────────────────────────────────

def process_one_shard(codec, s3, conn, shard: dict, wstate: WorkerState):
    """Download, encode and upload one shard synchronously, recording its status.

    All status updates go through the caller's ``conn``. main() in this file
    runs the same pipeline inline (with background prefetch) and does not call
    this function.

    Args:
        codec: model from load_codec().
        s3: boto3 S3 client.
        conn: open DB connection used for status updates.
        shard: dict from claim_shard() (uses "id" and "prefix").
        wstate: worker state to update for heartbeat reporting.

    Exceptions from the pipeline are logged and recorded on the shard row and
    in ``wstate``; they are not re-raised.
    """
    shard_id = shard["id"]
    prefix = shard["prefix"]

    log.info("═" * 60)
    log.info("Processing shard %d: %s", shard_id, prefix)
    log.info("═" * 60)

    wstate.set_phase("downloading", shard_id=shard_id, shard_prefix=prefix)

    try:
        update_shard_status(conn, shard_id, "processing", started_at="NOW()")

        with tempfile.TemporaryDirectory(prefix="neucodec_") as work_dir:
            work_dir = Path(work_dir)

            tar_path, meta_path, tar_size = download_shard(s3, prefix, work_dir)

            wstate.set_phase("encoding")
            tokens_df, total_audio_s, total_tokens, encode_wall = encode_shard(
                codec, tar_path, meta_path, wstate
            )

            if tokens_df.empty:
                raise RuntimeError("No segments encoded successfully")

            wstate.set_phase("uploading")
            output_key = upload_results(s3, prefix, tokens_df, work_dir)

            update_shard_status(
                conn, shard_id, "completed", completed_at="NOW()",
                audio_tar_bytes=tar_size,
                segment_count=len(tokens_df),
                total_audio_seconds=total_audio_s,
                total_tokens=total_tokens,
                encode_wall_seconds=encode_wall,
                output_key=output_key,
            )
            wstate.mark_shard_done(total_audio_s / 3600, total_tokens)
            log.info("Shard %d completed successfully", shard_id)

    except Exception as e:
        log.exception("Shard %d failed: %s", shard_id, e)
        wstate.mark_shard_failed(str(e))
        try:
            update_shard_status(conn, shard_id, "failed", failed_at="NOW()", error_message=str(e)[:500])
        except Exception as db_err:
            # Best effort: the row stays 'processing' and may later be reset by
            # reclaim_stale() once its claim goes stale.
            log.warning("Could not record failure for shard %d: %s", shard_id, db_err)


def main():
    """Run the worker: claim and encode shards until none remain or --max-shards is hit.

    While one shard is encoding, the next one is claimed and downloaded on a
    background thread, so the GPU does not wait on the network. A heartbeat
    thread reports state to neucodec_workers throughout. A failure on one
    shard (download, encode, upload) is recorded on that shard and the worker
    moves on.
    """
    import shutil

    parser = argparse.ArgumentParser(description="Neucodec encoding worker")
    parser.add_argument("--worker-id", default=None)
    parser.add_argument("--max-shards", type=int, default=0)
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    worker_id = args.worker_id or f"{socket.gethostname()}-{uuid.uuid4().hex[:8]}"

    has_cuda = torch.cuda.is_available()
    gpu_name = torch.cuda.get_device_name(0) if has_cuda else "NONE"
    gpu_vram = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024) if has_cuda else 0
    log.info("Worker starting: %s", worker_id)
    log.info("GPU: %s (%d MB)", gpu_name, gpu_vram)

    codec = load_codec()
    with torch.no_grad():
        codec.encode_code(torch.randn(1, 1, 16000).cuda())
    log.info("Codec warmup complete")

    # Initialize worker state and start heartbeat thread
    wstate = WorkerState(worker_id, gpu_name, gpu_vram)
    stop_hb = threading.Event()
    hb_thread = threading.Thread(target=heartbeat_loop, args=(wstate, stop_hb), daemon=True)
    hb_thread.start()

    s3 = get_s3()

    # Brief DB connection for initial reclaim
    conn = get_db()
    try:
        reclaim_stale(conn)
    finally:
        conn.close()

    def quota_reached(n_done):
        """True once a positive --max-shards limit has been reached."""
        return args.max_shards > 0 and n_done >= args.max_shards

    def record_failure(shard_id, err):
        """Best effort: mark a shard failed in the DB, logging (not raising) DB errors."""
        try:
            fc = get_db()
            try:
                update_shard_status(fc, shard_id, "failed", failed_at="NOW()", error_message=str(err)[:500])
            finally:
                fc.close()
        except Exception as db_err:
            log.warning("Could not mark shard %d failed: %s", shard_id, db_err)

    def fail_shard(shard, err):
        """Log with traceback (call inside an except block), update stats, record in DB."""
        log.exception("Shard %d failed: %s", shard["id"], err)
        wstate.mark_shard_failed(str(err))
        record_failure(shard["id"], err)

    def fetch(s3_client, shard, tmp_prefix):
        """Download a shard into a fresh temp dir, deleting the dir if the download fails."""
        work_dir = Path(tempfile.mkdtemp(prefix=tmp_prefix))
        try:
            tar_path, meta_path, tar_size = download_shard(s3_client, shard["prefix"], work_dir)
        except BaseException:
            shutil.rmtree(work_dir, ignore_errors=True)
            raise
        return shard, work_dir, tar_path, meta_path, tar_size

    shards_done = 0
    prefetch_result = [None]  # (shard, work_dir, tar_path, meta_path, tar_size) or None

    def prefetch_next(s3_client, worker_id_ref):
        """Claim and download the next shard in background while current one encodes."""
        shard = None
        try:
            pf_conn = get_db()
            try:
                shard = claim_shard(pf_conn, worker_id_ref)
                if shard is not None:
                    update_shard_status(pf_conn, shard["id"], "processing", started_at="NOW()")
            finally:
                pf_conn.close()  # Close DB during download
            if shard is not None:
                prefetch_result[0] = fetch(s3_client, shard, "neucodec_prefetch_")
        except Exception as e:
            log.warning("Prefetch failed: %s", e)
            if shard is not None:
                record_failure(shard["id"], e)
            prefetch_result[0] = None

    dry_run_seen = set()
    try:
        while True:
            conn = get_db()
            try:
                shard = claim_shard(conn, worker_id)
                if shard is not None:
                    if args.dry_run:
                        log.info("DRY RUN: shard %d: %s", shard["id"], shard["prefix"])
                        update_shard_status(conn, shard["id"], "pending", worker_id=None, claimed_at=None)
                    else:
                        wstate.set_phase("downloading", shard_id=shard["id"], shard_prefix=shard["prefix"])
                        update_shard_status(conn, shard["id"], "processing", started_at="NOW()")
            finally:
                conn.close()  # Never hold a DB connection during download + encode

            if shard is None:
                log.info("No more pending shards. Worker exiting.")
                break

            if args.dry_run:
                # A released shard can be handed straight back by the claim
                # function; stop instead of looping on it forever.
                if shard["id"] in dry_run_seen:
                    log.info("DRY RUN: shard %d already seen, stopping.", shard["id"])
                    break
                dry_run_seen.add(shard["id"])
                shards_done += 1
                if quota_reached(shards_done):
                    break
                continue

            # Download first shard synchronously
            try:
                job = fetch(s3, shard, "neucodec_")
            except Exception as e:
                # A bad download costs this shard, not the whole worker.
                fail_shard(shard, e)
                shards_done += 1
                job = None

            while job is not None:
                shard, work_dir, tar_path, meta_path, tar_size = job

                # Start prefetching next shard while we encode current one
                prefetch_result[0] = None
                pf_thread = None
                if not quota_reached(shards_done + 1):
                    pf_thread = threading.Thread(
                        target=prefetch_next, args=(get_s3(), worker_id), daemon=True
                    )
                    pf_thread.start()

                # Encode current shard (no DB connection held!)
                wstate.set_phase("encoding", shard_id=shard["id"], shard_prefix=shard["prefix"])
                try:
                    tokens_df, total_audio_s, total_tokens, encode_wall = encode_shard(
                        codec, tar_path, meta_path, wstate
                    )
                    if tokens_df.empty:
                        raise RuntimeError("No segments encoded successfully")

                    wstate.set_phase("uploading")
                    output_key = upload_results(s3, shard["prefix"], tokens_df, work_dir)

                    # Brief DB connection for status update
                    conn = get_db()
                    try:
                        update_shard_status(
                            conn, shard["id"], "completed", completed_at="NOW()",
                            audio_tar_bytes=tar_size, segment_count=len(tokens_df),
                            total_audio_seconds=total_audio_s, total_tokens=total_tokens,
                            encode_wall_seconds=encode_wall, output_key=output_key,
                        )
                    finally:
                        conn.close()
                    wstate.mark_shard_done(total_audio_s / 3600, total_tokens)
                    log.info("Shard %d completed", shard["id"])
                except Exception as e:
                    fail_shard(shard, e)
                finally:
                    # Cleanup current shard
                    shutil.rmtree(work_dir, ignore_errors=True)
                shards_done += 1

                if quota_reached(shards_done):
                    if pf_thread is not None:
                        pf_thread.join(timeout=5)
                    break

                # Wait for prefetch to complete
                if pf_thread is not None:
                    pf_thread.join()

                # Swap to prefetched shard — NO download wait! (None: claim the next one fresh)
                job = prefetch_result[0]
                if job is not None:
                    log.info("Prefetched shard %d ready, zero download wait", job[0]["id"])

            if quota_reached(shards_done):
                log.info("Reached max_shards=%d, exiting.", args.max_shards)
                break
    finally:
        # Mark worker as exited
        with wstate.lock:
            wstate.state = "exited"
            wstate.progress_phase = "exited"
        stop_hb.set()
        hb_thread.join(timeout=10)
        try:
            fc = get_db()
            try:
                cur = fc.cursor()
                cur.execute("""
                    UPDATE neucodec_workers SET state='exited', progress_phase='exited', last_heartbeat=NOW()
                    WHERE worker_id = %s
                """, (worker_id,))
                cur.close()
            finally:
                fc.close()
        except Exception as e:
            log.warning("Could not mark worker %s exited: %s", worker_id, e)

    log.info("Worker %s finished after %d shards", worker_id, shards_done)


if __name__ == "__main__":
    # Script entry point: importing this module does not start a worker.
    raise SystemExit(main())
