"""
Replay ledger sidecar helpers for recover workers.

Each successfully processed recover video uploads one compact per-video,
gzip-compressed JSON object that lists every replayed child ID and whether
it was:
  - already validated in prior runs
  - newly validated in this recover run
  - an extra replay-only child with no historical tx row
  - otherwise left unclassified (replayed, but neither validated, matched,
    nor marked extra)
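
Typical flow, as a sketch with illustrative values (``config`` is assumed
to be a populated ValidationConfig):

    payload = build_replay_ledger_payload(
        video_id="vid123",
        tx_rows=[{"segment_file": "vid123_0001", "transcription": "hello"}],
        replayed_segment_ids=["vid123_0001", "vid123_0002"],
        matched_tx_ids=["vid123_0001"],
        validated_segment_ids=set(),
        extra_regen_ids=["vid123_0002"],
        flag_summary={},
        worker_id="worker-7",
    )
    ReplayLedgerWriter(config).upload("vid123", payload)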
"""
from __future__ import annotations

import gzip
import json
import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional

from .config import ValidationConfig

logger = logging.getLogger(__name__)

REPLAY_STATUS_VALIDATED_EXISTING = "validated_existing"
REPLAY_STATUS_VALIDATED_NEW = "validated_new"
REPLAY_STATUS_EXTRA_NO_TX = "extra_no_tx"
REPLAY_STATUS_TX_UNCLASSIFIED = "historical_tx_unclassified"


@dataclass(frozen=True)
class ReplayLedgerArtifact:
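    """Object key and payload of an uploaded (or mock/skipped) replay ledger."""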
    key: str
    payload: dict


def build_replay_ledger_payload(
    *,
    video_id: str,
    tx_rows: list[dict],
    replayed_segment_ids: list[str],
    matched_tx_ids: list[str],
    validated_segment_ids: set[str],
    extra_regen_ids: list[str],
    flag_summary: dict,
    worker_id: str,
    missing_tx_ids: Optional[list[str]] = None,
    missing_parent_files: Optional[list[str]] = None,
) -> dict:
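    """Build the per-video replay ledger dict described in the module docstring.

    ``tx_rows`` are historical tx rows keyed by ``segment_file``; the ID
    collections determine each replayed segment's ``replay_status``.
    """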
    tx_by_id = {
        row["segment_file"]: row
        for row in tx_rows
        if row.get("segment_file")
    }
    replayed_ids = sorted(set(replayed_segment_ids))
    matched_ids = set(matched_tx_ids)
    validated_ids = set(validated_segment_ids)
    extra_ids = set(extra_regen_ids)

    status_counts = {
        REPLAY_STATUS_VALIDATED_EXISTING: 0,
        REPLAY_STATUS_VALIDATED_NEW: 0,
        REPLAY_STATUS_EXTRA_NO_TX: 0,
        REPLAY_STATUS_TX_UNCLASSIFIED: 0,
    }
    entries: list[dict] = []

    for seg_id in replayed_ids:
        row = tx_by_id.get(seg_id)
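        # Precedence matters: explicit replay-only extras win, then segments
        # validated in prior runs, then newly matched ones; anything left is
        # recorded as unclassified.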
        if seg_id in extra_ids:
            replay_status = REPLAY_STATUS_EXTRA_NO_TX
        elif seg_id in validated_ids:
            replay_status = REPLAY_STATUS_VALIDATED_EXISTING
        elif seg_id in matched_ids:
            replay_status = REPLAY_STATUS_VALIDATED_NEW
        else:
            replay_status = REPLAY_STATUS_TX_UNCLASSIFIED

        status_counts[replay_status] += 1
        entries.append({
            "segment_file": seg_id,
            "replay_status": replay_status,
            "has_historical_tx_row": row is not None,
            "historical_lang": _historical_lang(row),
            "historical_transcription": row.get("transcription") if row else None,
            "historical_tagged": row.get("tagged") if row else None,
            "historical_quality_score": float(row.get("quality_score") or 0.0) if row else None,
        })

    missing_tx_list = sorted(set(missing_tx_ids or []))
    missing_parent_list = sorted(set(missing_parent_files or []))
    return {
        "schema_version": 1,
        "video_id": video_id,
        "worker_id": worker_id,
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "summary": {
            "replayed_segments": len(replayed_ids),
            "historical_tx_rows": len(tx_by_id),
            "validated_existing_segments": status_counts[REPLAY_STATUS_VALIDATED_EXISTING],
            "validated_new_segments": status_counts[REPLAY_STATUS_VALIDATED_NEW],
            "extra_no_tx_segments": status_counts[REPLAY_STATUS_EXTRA_NO_TX],
            "historical_tx_unclassified_segments": status_counts[REPLAY_STATUS_TX_UNCLASSIFIED],
            "missing_tx_segments": len(missing_tx_list),
            "missing_parent_files": missing_parent_list,
            "extra_timeout_segments": int(flag_summary.get("timeout", 0) or 0),
            "extra_error_segments": int(flag_summary.get("error", 0) or 0),
            "extra_rate_limited_segments": int(flag_summary.get("rate_limited", 0) or 0),
            "extra_flagged_segments": int(flag_summary.get("flagged_total", 0) or 0),
        },
        "missing_tx_ids": missing_tx_list,
        "entries": entries,
    }


class ReplayLedgerWriter:
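    """Uploads replay ledger payloads to R2 as gzip-compressed JSON objects."""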
    def __init__(self, config: ValidationConfig):
        self.config = config
        self._s3 = None

    def upload(self, video_id: str, payload: dict) -> ReplayLedgerArtifact:
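        """Gzip-compress ``payload`` and upload it under ``object_key(video_id)``.

        In mock or skip-upload mode the target key is only logged and no
        request is made.
        """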
        key = self.object_key(video_id)
        if self.config.mock_mode or self.config.r2_skip_upload:
            logger.info(f"[MOCK/SKIP] Would upload replay ledger -> s3://{self.config.r2_bucket_output}/{key}")
            return ReplayLedgerArtifact(key=key, payload=payload)

        body = gzip.compress(
            json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
        )
        s3 = self._get_s3()
        s3.put_object(
            Bucket=self.config.r2_bucket_output,
            Key=key,
            Body=body,
            ContentType="application/json",
            ContentEncoding="gzip",
        )
        logger.info(f"Uploaded replay ledger -> s3://{self.config.r2_bucket_output}/{key}")
        return ReplayLedgerArtifact(key=key, payload=payload)

    def object_key(self, video_id: str) -> str:
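        """Return ``<prefix>/<video_id>.json.gz`` under the configured prefix, if any."""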
        prefix = self.config.recover_replay_ledger_prefix.strip("/")
        if prefix:
            return f"{prefix}/{video_id}.json.gz"
        return f"{video_id}.json.gz"

    def _get_s3(self):
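        """Lazily construct and cache a boto3 S3 client pointed at the R2 endpoint."""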
        if self._s3 is None:
            import boto3

            self._s3 = boto3.client(
                "s3",
                endpoint_url=self.config.r2_endpoint_url,
                aws_access_key_id=self.config.r2_access_key_id,
                aws_secret_access_key=self.config.r2_secret_access_key,
                region_name="auto",
            )
        return self._s3


def _historical_lang(row: Optional[dict]) -> Optional[str]:
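    """Prefer the row's detected language, falling back to the expected hint.

    Empty strings are normalized to None.
    """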
    if not row:
        return None
    return row.get("detected_language") or row.get("expected_language_hint") or None
