#!/usr/bin/env python3
"""Phase 2: Convert a single shard's audio.tar from 48kHz to 16kHz FLAC.

Streaming: reads each member from the input tar, decodes it, downmixes to mono,
decimates 48k→16k (exact 3:1, keeping every third sample), encodes back to FLAC
and writes it to the output tar. No WAV intermediate on disk.

Atomic swap: writes .tmp, validates, swaps, deletes old.

Usage:
  python3 tools/phase2_convert_shard_audio16.py /path/to/shard_dir
  python3 tools/phase2_convert_shard_audio16.py /path/to/shard_dir --dry-run
"""

import argparse
import hashlib
import io
import json
import os
import sys
import tarfile
import time
from pathlib import Path

import soundfile

# Ask OpenMP- and MKL-backed libraries for a single thread. setdefault() only
# fills in a value when the variable is absent, so a value already present in
# the environment wins.
for _thread_env in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ.setdefault(_thread_env, "1")
del _thread_env


def _to_mono_16k(data, sr):
    """Downmix decoded audio to mono and bring it to 16 kHz.

    Args:
        data: Sample array as returned by ``soundfile.read``; arrays with more
            than one dimension are averaged over axis 1.
        sr: Sample rate of ``data`` in Hz.

    Returns:
        The mono samples at 16 kHz.
    """
    if data.ndim > 1:
        data = data.mean(axis=1)

    if sr == 16000:
        return data
    if sr % 16000 == 0:
        # Integer ratio (48 kHz -> 3:1): keep every ratio-th sample.
        # NOTE(review): no low-pass filter is applied before dropping samples,
        # so content above 8 kHz in the source would alias into the output.
        # Confirm the sources are band-limited, or route this through soxr too.
        return data[:: sr // 16000]

    # General case: use soxr
    import soxr

    return soxr.resample(data, sr, 16000)


def convert_shard(shard_dir: Path, dry_run: bool = False) -> dict:
    """Convert one shard's audio.tar to 16kHz mono FLAC.

    Every regular file in ``audio.tar`` is decoded, downmixed, resampled and
    re-encoded into ``audio_16k.tar.tmp``. The new tar is validated and an
    index is staged before the original is touched; only then is the original
    moved to ``audio_48k.bak``, the new tar moved into place, re-verified, and
    the backup deleted. On any failure the temp files are removed and the
    backup, if present, is moved back to ``audio.tar``.

    Members that fail to decode or encode are recorded in the status and are
    NOT carried over to the new tar, so they are gone once the backup is
    deleted.

    Args:
        shard_dir: Directory containing ``audio.tar``.
        dry_run: If true, only check preconditions and report; nothing is
            written.

    Returns:
        A status dict whose ``"status"`` is one of ``"skipped"`` (a previous
        run already succeeded), ``"dry_run"``, ``"success"`` or ``"error"``.
        The ``"success"`` and conversion-time ``"error"`` dicts are also
        written to ``shard_conversion_status.json``.
    """
    audio_tar = shard_dir / "audio.tar"
    tmp_tar = shard_dir / "audio_16k.tar.tmp"
    bak_tar = shard_dir / "audio_48k.bak"
    index_file = shard_dir / "audio_16k_index.parquet"
    tmp_index = shard_dir / "audio_16k_index.parquet.tmp"
    status_file = shard_dir / "shard_conversion_status.json"

    # Check if already converted. An unreadable or malformed status file is
    # treated as "not converted" instead of crashing before any work is done.
    if status_file.exists():
        try:
            existing = json.loads(status_file.read_text())
        except (OSError, ValueError):
            existing = None
        if isinstance(existing, dict) and existing.get("status") == "success":
            return {"status": "skipped", "reason": "already converted"}

    # A leftover backup still holds the original audio, so it counts as input.
    if not audio_tar.exists() and not bak_tar.exists():
        return {"status": "error", "reason": f"audio.tar not found in {shard_dir}"}

    if dry_run:
        return {"status": "dry_run", "shard_dir": str(shard_dir)}

    t0 = time.time()
    input_count = 0
    output_count = 0
    total_input_secs = 0.0
    total_output_secs = 0.0
    errors = []
    index_rows = []

    try:
        # The backup only exists between the swap and its deletion, so finding
        # one here means an earlier run died mid-swap. It is the untouched
        # original: put it back and convert from it.
        if bak_tar.exists():
            bak_tar.replace(audio_tar)

        with tarfile.open(audio_tar, "r") as tf_in, tarfile.open(
            tmp_tar, "w"
        ) as tf_out:
            for member in tf_in:
                if not member.isfile():
                    continue
                input_count += 1

                # Only problems with this member's audio are tolerated here.
                try:
                    raw = tf_in.extractfile(member).read()
                    data, sr = soundfile.read(io.BytesIO(raw))
                    total_input_secs += len(data) / sr

                    data_16k = _to_mono_16k(data, sr)
                    total_output_secs += len(data_16k) / 16000

                    # Encode to FLAC
                    buf = io.BytesIO()
                    soundfile.write(buf, data_16k, 16000, format="FLAC")
                    flac_bytes = buf.getvalue()
                except Exception as e:
                    errors.append({"member": member.name, "error": str(e)})
                    continue

                # Writing to the output tar is deliberately outside the
                # per-member try: a failure here (e.g. disk full) can leave the
                # tar damaged and must abort the whole shard.
                info = tarfile.TarInfo(name=member.name)
                info.size = len(flac_bytes)
                info.mtime = member.mtime
                tf_out.addfile(info, io.BytesIO(flac_bytes))

                index_rows.append(
                    {
                        "member_name": member.name,
                        "nbytes": len(flac_bytes),
                        "duration_s": round(len(data_16k) / 16000, 4),
                        "sample_rate_hz": 16000,
                        "channels": 1,
                        "sha256": hashlib.sha256(flac_bytes).hexdigest(),
                    }
                )
                output_count += 1

        # Validate: counts must match (minus errors)
        expected = input_count - len(errors)
        if output_count != expected:
            raise RuntimeError(
                f"Count mismatch: input={input_count} errors={len(errors)} "
                f"expected={expected} got={output_count}"
            )
        # Never replace the original with an empty tar.
        if output_count == 0:
            raise RuntimeError(
                f"No audio members converted: input={input_count} "
                f"errors={len(errors)}"
            )

        # Spot-check: decode the first, middle and last file of the new tar.
        # Explicit raises rather than assert, which is stripped under -O.
        with tarfile.open(tmp_tar, "r") as tf_check:
            members = [m for m in tf_check.getmembers() if m.isfile()]
            if len(members) != output_count:
                raise RuntimeError(
                    f"Pre-swap verify: expected {output_count} got {len(members)}"
                )
            for idx in sorted({0, len(members) // 2, len(members) - 1}):
                m = members[idx]
                f = tf_check.extractfile(m)
                data, sr = soundfile.read(io.BytesIO(f.read()))
                if sr != 16000:
                    raise RuntimeError(f"Spot-check: {m.name} sr={sr}")
                if data.ndim != 1:
                    raise RuntimeError(f"Spot-check: {m.name} not mono")

        # Stage the index before touching the original, so a missing pyarrow
        # or a failed write aborts while audio.tar is still intact.
        import pyarrow as pa
        import pyarrow.parquet as pq

        idx_table = pa.table(
            {col: [r[col] for r in index_rows] for col in index_rows[0].keys()}
        )
        pq.write_table(idx_table, tmp_index)

        # Swap: original -> backup, new tar -> audio.tar
        audio_tar.rename(bak_tar)
        tmp_tar.rename(audio_tar)

        # Verify new tar is readable
        with tarfile.open(audio_tar, "r") as tf_verify:
            verify_count = sum(1 for m in tf_verify if m.isfile())
        if verify_count != output_count:
            raise RuntimeError(
                f"Post-swap verify: expected {output_count} got {verify_count}"
            )

        # Publish the index, then drop the backup. From here on there is no
        # way back to the original audio.
        tmp_index.replace(index_file)
        bak_tar.unlink()

        elapsed = time.time() - t0

        # Write status
        status = {
            "status": "success",
            "shard_dir": str(shard_dir),
            "input_count": input_count,
            "output_count": output_count,
            "error_count": len(errors),
            "errors": errors[:10],
            "input_hours": round(total_input_secs / 3600, 3),
            "output_hours": round(total_output_secs / 3600, 3),
            "elapsed_s": round(elapsed, 1),
            "audio_sec_per_sec": round(total_input_secs / max(elapsed, 0.01), 0),
        }
        with open(status_file, "w") as f:
            json.dump(status, f, indent=2)

        return status

    except Exception as e:
        # Cleanup on failure: remove temp files, then restore the original.
        for leftover in (tmp_tar, tmp_index):
            if leftover.exists():
                leftover.unlink()
        # If the backup still exists, the failure happened during or right
        # after the swap. Put the original back even when a new, unverified
        # audio.tar is already in place.
        if bak_tar.exists():
            bak_tar.replace(audio_tar)

        status = {
            "status": "error",
            "shard_dir": str(shard_dir),
            "error": str(e),
            "input_count": input_count,
            "output_count": output_count,
        }
        with open(status_file, "w") as f:
            json.dump(status, f, indent=2)
        return status


def main():
    """Command-line entry point: convert one shard and print the result as JSON.

    Exits with status 1 when the conversion reports ``"error"``.
    """
    cli = argparse.ArgumentParser(description="Convert shard audio to 16kHz")
    cli.add_argument("shard_dir", type=Path)
    cli.add_argument("--dry-run", action="store_true")
    opts = cli.parse_args()

    outcome = convert_shard(opts.shard_dir, dry_run=opts.dry_run)
    print(json.dumps(outcome, indent=2))

    failed = outcome["status"] == "error"
    if failed:
        sys.exit(1)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
