"""Unit tests for split_manifest.py — uses synthetic data only."""

import json
import subprocess
import sys
from pathlib import Path

# Grandparent of the directory containing this (resolved) test file.
# Subprocesses below run with this as their cwd, so the relative script path
# "scripts/split_manifest.py" must exist directly beneath it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent


def _make_manifest(tmp_path, rows):
    """Write rows to a JSONL file and return its path."""
    path = tmp_path / "input.jsonl"
    with open(path, "w") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return path


def _sample_rows(lang, count):
    return [
        {
            "audio_filepath": f"/data/{lang}/audio.tar",
            "tar_member": f"seg_{i}.flac",
            "text": f"text {i}",
            "duration": 3.0,
            "lang": lang,
            "taskname": "asr",
            "source_lang": lang,
            "target_lang": lang,
        }
        for i in range(count)
    ]


def _run_split(tmp_path, manifest, extra_args=None):
    """Run ``scripts/split_manifest.py`` on ``manifest`` with a fixed ratio and seed.

    The script is invoked with ``--val-ratio 0.1 --seed 42``; any
    ``extra_args`` are appended after those fixed arguments. Outputs go to
    ``train.jsonl`` and ``val.jsonl`` inside ``tmp_path``.

    Args:
        tmp_path: Directory that receives the train/val output files.
        manifest: Path to the input JSONL manifest.
        extra_args: Optional list of additional command-line arguments.

    Returns:
        Tuple ``(result, train_path, val_path)`` where ``result`` is the
        ``subprocess.CompletedProcess`` with text ``stdout``/``stderr``.
    """
    train = tmp_path / "train.jsonl"
    val = tmp_path / "val.jsonl"
    cmd = [
        # Use the interpreter running pytest: "python3" may resolve to a
        # different installation/venv, or not exist at all (e.g. Windows).
        sys.executable,
        "scripts/split_manifest.py",
        "--input",
        str(manifest),
        "--train-output",
        str(train),
        "--val-output",
        str(val),
        "--val-ratio",
        "0.1",
        "--seed",
        "42",
    ]
    if extra_args:
        cmd.extend(extra_args)
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        cwd=str(PROJECT_ROOT),
        check=False,  # callers inspect returncode/stderr themselves
    )
    return result, train, val


def test_split_deterministic(tmp_path):
    """Same seed should produce identical outputs."""
    manifest = _make_manifest(tmp_path, _sample_rows("en", 100) + _sample_rows("hi", 100))
    outputs = []
    for run in range(2):
        # A fresh directory per run keeps each run's output files distinct, so
        # the comparison cannot pass by re-reading the first run's files.
        run_dir = tmp_path / f"run_{run}"
        run_dir.mkdir()
        result, train, val = _run_split(run_dir, manifest)
        assert result.returncode == 0, result.stderr
        outputs.append((train.read_text(), val.read_text()))

    assert outputs[0][0] == outputs[1][0], "Train outputs differ between runs"
    assert outputs[0][1] == outputs[1][1], "Val outputs differ between runs"


def test_split_preserves_total(tmp_path):
    """Every input row must land in a split: len(train) + len(val) == len(input)."""
    rows = _sample_rows("en", 50) + _sample_rows("hi", 50)
    result, train, val = _run_split(tmp_path, _make_manifest(tmp_path, rows))
    assert result.returncode == 0, result.stderr

    written = sum(
        len([line for line in out.read_text().strip().split("\n") if line])
        for out in (train, val)
    )
    assert written == 100

def test_split_stratifies_by_lang_default(tmp_path):
    """Without any stratify flag, both languages must appear in train and in val."""
    rows = _sample_rows("en", 100) + _sample_rows("hi", 100)
    result, train, val = _run_split(tmp_path, _make_manifest(tmp_path, rows))
    assert result.returncode == 0, result.stderr

    def langs_in(path):
        # Collect the distinct "lang" values across the non-empty JSONL lines.
        text = path.read_text().strip()
        return {json.loads(line)["lang"] for line in text.split("\n") if line}

    train_langs, val_langs = langs_in(train), langs_in(val)
    assert {"en", "hi"} <= train_langs
    assert {"en", "hi"} <= val_langs

def test_split_stratify_by_lang_explicit_flag(tmp_path):
    """Passing --stratify-by-lang explicitly should match the default behavior."""
    manifest = _make_manifest(tmp_path, _sample_rows("en", 100) + _sample_rows("hi", 100))
    result, _train, val = _run_split(tmp_path, manifest, extra_args=["--stratify-by-lang"])
    assert result.returncode == 0, result.stderr

    seen = set()
    for line in val.read_text().strip().split("\n"):
        if line:
            seen.add(json.loads(line)["lang"])
    assert {"en", "hi"} <= seen

def test_split_no_stratify_by_lang_flag(tmp_path):
    """With --no-stratify-by-lang, all input rows must still be written out."""
    manifest = _make_manifest(tmp_path, _sample_rows("en", 100) + _sample_rows("hi", 100))
    result, train, val = _run_split(tmp_path, manifest, extra_args=["--no-stratify-by-lang"])
    assert result.returncode == 0, result.stderr

    counts = []
    for out in (train, val):
        counts.append(sum(1 for line in out.read_text().strip().split("\n") if line))
    assert sum(counts) == 200

def test_split_no_stratify_legacy_alias(tmp_path):
    """The legacy --no-stratify spelling must still be accepted and split all rows."""
    manifest = _make_manifest(tmp_path, _sample_rows("en", 100) + _sample_rows("hi", 100))
    result, train, val = _run_split(tmp_path, manifest, extra_args=["--no-stratify"])
    assert result.returncode == 0, result.stderr

    n_train = sum(1 for line in train.read_text().strip().split("\n") if line)
    n_val = sum(1 for line in val.read_text().strip().split("\n") if line)
    assert n_train + n_val == 200

def _run_with_val_ratio(tmp_path, val_ratio):
    """Run split_manifest.py on a 10-row English manifest with ``val_ratio``.

    Only ``--input``, the two output paths and ``--val-ratio`` are passed
    (no ``--seed``), matching what each invalid-ratio test exercises.

    Args:
        tmp_path: Directory for the input manifest and the output files.
        val_ratio: String passed verbatim as the ``--val-ratio`` value.

    Returns:
        The ``subprocess.CompletedProcess`` with text ``stdout``/``stderr``.
    """
    manifest = _make_manifest(tmp_path, _sample_rows("en", 10))
    return subprocess.run(
        [
            # Same interpreter as pytest; "python3" may differ or be absent.
            sys.executable,
            "scripts/split_manifest.py",
            "--input",
            str(manifest),
            "--train-output",
            str(tmp_path / "t.jsonl"),
            "--val-output",
            str(tmp_path / "v.jsonl"),
            "--val-ratio",
            val_ratio,
        ],
        capture_output=True,
        text=True,
        cwd=str(PROJECT_ROOT),
        check=False,  # a non-zero exit is the expected outcome here
    )


def _assert_val_ratio_rejected(result):
    """Assert the run exited non-zero and reported the val-ratio range error."""
    assert result.returncode != 0, result.stdout
    assert "--val-ratio must be >0 and <1" in result.stderr


def test_split_invalid_val_ratio_zero(tmp_path):
    """--val-ratio 0 should fail."""
    _assert_val_ratio_rejected(_run_with_val_ratio(tmp_path, "0"))


def test_split_invalid_val_ratio_one(tmp_path):
    """--val-ratio 1 should fail."""
    _assert_val_ratio_rejected(_run_with_val_ratio(tmp_path, "1.0"))


def test_split_invalid_val_ratio_negative(tmp_path):
    """--val-ratio -0.5 should fail."""
    _assert_val_ratio_rejected(_run_with_val_ratio(tmp_path, "-0.5"))


def test_split_invalid_val_ratio_greater_than_one(tmp_path):
    """--val-ratio 1.5 should fail."""
    _assert_val_ratio_rejected(_run_with_val_ratio(tmp_path, "1.5"))
