from __future__ import annotations

import asyncio
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

from validations.config import ValidationConfig
from validations.recover_reference_store import RecoverReferenceStore, _s3_pool_size


def _write_parquet(path: Path, rows: list[dict]) -> None:
    """Write *rows* (flat dicts) to *path* as a zstd-compressed Parquet file.

    Column order follows the keys of the first row; keys absent from later
    rows become nulls in that column.

    Raises:
        ValueError: if *rows* is empty (there is no schema to infer).
    """
    if not rows:
        raise ValueError("rows required")
    first, *_ = rows
    table = pa.table({name: [entry.get(name) for entry in rows] for name in first})
    pq.write_table(table, path, compression="zstd")


def test_recover_reference_store_reads_tx_rows_and_flag_summary(tmp_path: Path):
    """In parquet mode the store returns per-video tx rows and a flag summary.

    Seeds a transcription-results parquet (two vid-a segments, one vid-b
    segment) plus a flags parquet, opens the store against a local duckdb
    file, and checks both async fetch paths.
    """
    tx_path = tmp_path / "transcription_results_recover.parquet"
    flags_path = tmp_path / "transcription_flags_recover.parquet"

    # Shared "no speaker metadata" fields for the rows that leave them blank.
    blank_speaker = {
        "speaker_emotion": "",
        "speaker_style": "",
        "speaker_pace": "",
        "speaker_accent": "",
    }
    tx_fixture = [
        {
            "video_id": "vid-a",
            "segment_file": "seg-001",
            "expected_language_hint": "hi",
            "detected_language": "hi",
            "transcription": "hello",
            "tagged": "hello",
            "quality_score": 0.9,
            "speaker_emotion": "neutral",
            "speaker_style": "plain",
            "speaker_pace": "normal",
            "speaker_accent": "north",
        },
        {
            "video_id": "vid-a",
            "segment_file": "seg-002",
            "expected_language_hint": "hi",
            "detected_language": "hi",
            "transcription": "world",
            "tagged": "world",
            "quality_score": 0.8,
            **blank_speaker,
        },
        {
            "video_id": "vid-b",
            "segment_file": "seg-101",
            "expected_language_hint": "en",
            "detected_language": "en",
            "transcription": "other",
            "tagged": "other",
            "quality_score": 0.7,
            **blank_speaker,
        },
    ]
    _write_parquet(tx_path, tx_fixture)

    # Note the duplicated ("extra-1", "timeout") pair: the expected summary
    # below counts timeout only once.
    flag_pairs = [
        ("extra-1", "timeout"),
        ("extra-1", "timeout"),
        ("extra-2", "error"),
        ("extra-3", "rate_limited"),
    ]
    _write_parquet(
        flags_path,
        [{"segment_id": seg, "flag_type": kind} for seg, kind in flag_pairs],
    )

    config = ValidationConfig(
        mock_mode=True,
        recover_reference_mode="parquet",
        recover_tx_parquet_key=tx_path.name,
        recover_flags_parquet_key=flags_path.name,
    )
    store = RecoverReferenceStore(config, tmp_path)
    # Point the store at the local fixtures instead of its default locations.
    store.tx_path = tx_path
    store.flags_path = flags_path
    store.duckdb_path = tmp_path / "recover_reference.duckdb"
    store._open_duckdb()

    try:
        # Only vid-a segments are returned, in segment order; vid-b is excluded.
        fetched = asyncio.run(store.fetch_tx_rows("vid-a"))
        assert [entry["segment_file"] for entry in fetched] == ["seg-001", "seg-002"]
        assert fetched[0]["transcription"] == "hello"

        # Summary over two known ids plus one unknown id: "extra-3" is not
        # requested, so rate_limited stays 0; flagged_total presumably counts
        # distinct flagged segments (2) — confirm against store semantics.
        summary = asyncio.run(store.fetch_flag_summary(["extra-1", "extra-2", "missing"]))
        assert summary == {
            "timeout": 1,
            "error": 1,
            "rate_limited": 0,
            "flagged_total": 2,
        }
    finally:
        asyncio.run(store.close())


def test_recover_reference_store_sizes_s3_pool_for_parallel_downloads():
    """The S3 pool is sized at least as large as the requested parallelism."""
    for parallelism in (16, 1):
        assert _s3_pool_size(parallelism) >= parallelism
