from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq

from src.final_export_common import (
    build_pack_artifacts,
    compute_audio_metrics,
    segment_audio_member_name,
)


def test_segment_audio_member_name_handles_split_suffix() -> None:
    """Check the member name produced for plain and ``_split``-suffixed segment names.

    A name that already ends in ``.flac`` is expected back unchanged, while a
    name carrying a trailing ``_split0`` is expected to gain a ``.flac`` suffix.
    """
    expected_by_segment_name = {
        "SPEAKER_00_0123_10.00-18.42.flac": "SPEAKER_00_0123_10.00-18.42.flac",
        "SPEAKER_00_0123_10.00-18.42.flac_split0": "SPEAKER_00_0123_10.00-18.42.flac_split0.flac",
    }
    for segment_name, member_name in expected_by_segment_name.items():
        assert segment_audio_member_name(segment_name) == member_name


def test_compute_audio_metrics_has_expected_keys() -> None:
    """Feed a two-second 220 Hz sine tone through ``compute_audio_metrics``.

    The test only checks that the returned mapping exposes the expected keys
    with plausible values; it does not pin exact loudness figures.
    """
    sample_rate = 16000
    seconds = 2.0
    sample_count = int(sample_rate * seconds)
    timeline = np.linspace(0, seconds, sample_count, endpoint=False)
    tone = 0.25 * np.sin(2 * np.pi * 220 * timeline)

    metrics = compute_audio_metrics(tone.astype(np.float32), sample_rate, "hello world")

    assert metrics["duration_s"] > 1.9
    assert metrics["sample_rate_hz"] == sample_rate
    # Level measurements must be present for a non-silent signal.
    for level_key in ("rms_dbfs", "peak_dbfs"):
        assert metrics[level_key] is not None
    # A non-empty transcript over a non-zero duration yields positive rates.
    for rate_key in ("chars_per_sec", "words_per_sec"):
        assert metrics[rate_key] > 0


def test_build_pack_artifacts_writes_manifest_and_sidecars(tmp_path: Path) -> None:
    """Build a two-segment pack and verify the manifest and parquet sidecars.

    The manifest must report row/member counts, the summed FLAC payload size,
    and file sizes that match what is actually on disk. Both parquet files
    must list the segments in the order they were supplied.
    """
    # (video_id, segment_id, transcript, duration, tar member, flac payload, sha)
    fixtures = (
        ("video1", "seg_a", "namaste", 2.0, "seg_a.flac", b"fake_flac_a", "sha_a"),
        ("video2", "seg_b", "hello", 3.0, "seg_b.flac", b"fake_flac_b", "sha_b"),
    )
    metadata_rows = [
        {
            "video_id": video_id,
            "segment_id": segment_id,
            "segment_language": "hi",
            "transcription_mixed": transcript,
            "duration_s": duration,
        }
        for video_id, segment_id, transcript, duration, _member, _payload, _sha in fixtures
    ]
    audio_rows = [
        {
            "video_id": video_id,
            "segment_id": segment_id,
            "tar_member_name": member,
            "flac_bytes": payload,
            "flac_sha256": sha,
            "audio_duration_s": duration,
        }
        for video_id, segment_id, _transcript, duration, member, payload, sha in fixtures
    ]
    # Computed up front so the expectation does not depend on what the builder
    # does with the row dicts it receives.
    expected_flac_total = sum(len(payload) for *_, payload, _sha in fixtures)

    pack_payload = {
        "shard_id": "hi_shard_000001",
        "language": "hi",
        "segment_count": 2,
        "video_count": 2,
    }
    artifacts = build_pack_artifacts(
        pack_dir=tmp_path / "pack",
        manifest_name="manifest.json",
        metadata_rows=metadata_rows,
        audio_rows=audio_rows,
        manifest_payload=pack_payload,
    )

    manifest = json.loads(artifacts.manifest_path.read_text())

    expected_counts = {
        "metadata_row_count": 2,
        "audio_index_row_count": 2,
        "audio_tar_member_count": 2,
    }
    for count_key, count in expected_counts.items():
        assert manifest[count_key] == count

    assert manifest["sum_flac_bytes"] == expected_flac_total

    # Recorded sizes must agree with the files written to disk.
    size_checks = (
        ("metadata_size_bytes", artifacts.metadata_path),
        ("audio_tar_size_bytes", artifacts.audio_tar_path),
        ("audio_index_size_bytes", artifacts.audio_index_path),
    )
    for size_key, artifact_path in size_checks:
        assert manifest[size_key] == artifact_path.stat().st_size

    metadata_records = pq.read_table(artifacts.metadata_path).to_pylist()
    audio_index_records = pq.read_table(artifacts.audio_index_path).to_pylist()
    expected_ids = ["seg_a", "seg_b"]
    assert [record["segment_id"] for record in metadata_records] == expected_ids
    assert [record["segment_id"] for record in audio_index_records] == expected_ids
