"""
Recovery loader: rebuild validation-ready segments from raw `1-cleaned-data`
audio and historical transcription rows.

This is the core primitive for the recover path:
  - replay `audio_polish` from raw parent FLACs
  - regenerate deterministic child segment IDs
  - intersect them with historical `transcription_results`
  - return SegmentData objects ready for the validation models
"""
from __future__ import annotations

import json
import logging
import re
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import numpy as np
import torch

from src.audio_polish import polish_segment

from .audio_loader import SegmentData

logger = logging.getLogger(__name__)

# Child segment IDs carry a trailing "_split<N>" marker added during splitting.
_SPLIT_SUFFIX_RE = re.compile(r"_split\d+$")


def parent_segment_file(segment_id: str) -> str:
    """Return the raw parent FLAC filename for a (possibly split) segment ID.

    IDs without a ``_split<N>`` suffix are already parent filenames and come
    back unchanged.
    """
    parent, _n_subs = _SPLIT_SUFFIX_RE.subn("", segment_id)
    return parent


def replay_segment_id(original_file: str, was_split: bool, split_index: int) -> str:
    """Rebuild the segment ID exactly like `src.pipeline` does."""
    suffix = f"_split{split_index}" if was_split else ""
    return f"{original_file}{suffix}"


@dataclass
class RecoveryLoadResult:
    """Outcome of one recovery pass: rebuilt segments plus bookkeeping lists."""

    # Parsed contents of the extracted metadata.json ({} when none was found).
    metadata: dict
    # Validation-ready segments whose replayed IDs matched historical tx rows.
    segments: list[SegmentData] = field(default_factory=list)
    # Requested historical IDs that were successfully regenerated (sorted).
    matched_tx_ids: list[str] = field(default_factory=list)
    # Requested IDs that could not be regenerated, plus IDs unknown to tx_rows.
    missing_tx_ids: list[str] = field(default_factory=list)
    # IDs produced by replay that have no historical transcription row
    # (salvage candidates for the caller to log separately).
    extra_regen_ids: list[str] = field(default_factory=list)
    # Parent FLAC filenames absent from disk or whose replay raised.
    missing_parent_files: list[str] = field(default_factory=list)


def load_recover_segments(
    work_dir: Path,
    video_id: str,
    tx_rows: list[dict],
    target_segment_ids: Optional[set[str]] = None,
) -> RecoveryLoadResult:
    """
    Rebuild validation-ready SegmentData from raw extracted tar + tx rows.

    `tx_rows` should come from `transcription_results` and include at least:
      - segment_file
      - transcription
      - tagged
      - detected_language
      - quality_score

    If `target_segment_ids` is provided, only those historical IDs are returned.
    All replay-only IDs for the processed parents are surfaced via `extra_regen_ids`
    so the caller can log salvage candidates separately.

    Raises:
        json.JSONDecodeError: if a metadata.json is present but malformed.
    """
    # Tar layouts vary: segments may live under <work_dir>/<video_id>/ or
    # directly under <work_dir>.
    video_dir = work_dir / video_id
    if not video_dir.exists():
        video_dir = work_dir

    metadata_path = _find_file(video_dir, "metadata.json")
    metadata = {}
    if metadata_path:
        # Explicit encoding: bare read_text() uses the platform default, which
        # mis-decodes UTF-8 metadata on non-UTF-8 locales (e.g. Windows cp1252).
        metadata = json.loads(metadata_path.read_text(encoding="utf-8"))

    segments_dir = _find_dir(video_dir, "segments")
    if not segments_dir:
        # Lazy %-args so formatting only happens when WARNING is enabled.
        logger.warning("[%s] No raw segments/ dir found in %s", video_id, video_dir)
        return RecoveryLoadResult(metadata=metadata)

    # Historical rows keyed by segment ID; rows without an ID are unusable.
    tx_map = {row["segment_file"]: row for row in tx_rows if row.get("segment_file")}
    requested_ids = set(target_segment_ids) if target_segment_ids is not None else set(tx_map)

    # If the caller passes IDs not present in tx_rows, surface them as missing.
    unknown_target_ids = requested_ids - set(tx_map)
    requested_ids -= unknown_target_ids

    # Group IDs by raw parent FLAC so each parent is polished at most once.
    parent_to_requested_ids: dict[str, set[str]] = defaultdict(set)
    parent_to_all_tx_ids: dict[str, set[str]] = defaultdict(set)
    for seg_id in tx_map:
        parent_to_all_tx_ids[parent_segment_file(seg_id)].add(seg_id)
    for seg_id in requested_ids:
        parent_to_requested_ids[parent_segment_file(seg_id)].add(seg_id)

    raw_paths = {path.name: path for path in segments_dir.glob("*.flac")}

    result = RecoveryLoadResult(metadata=metadata)
    matched: set[str] = set()
    extras: set[str] = set()

    # Sorted for deterministic processing and logging order across runs.
    for parent_file in sorted(parent_to_requested_ids):
        raw_path = raw_paths.get(parent_file)
        if raw_path is None:
            result.missing_parent_files.append(parent_file)
            continue

        try:
            polished = polish_segment(raw_path)
        except Exception as e:  # pragma: no cover - defensive logging path
            logger.warning("[%s] Failed to replay %s: %s", video_id, parent_file, e)
            result.missing_parent_files.append(parent_file)
            continue

        parent_tx_ids = parent_to_all_tx_ids.get(parent_file, set())
        regen_ids_for_parent: list[str] = []

        for seg in polished:
            # Discarded children never got historical IDs; skip them entirely.
            if seg.trim_meta.discarded:
                continue

            seg_id = replay_segment_id(
                seg.trim_meta.original_file,
                seg.trim_meta.was_split,
                seg.trim_meta.split_index,
            )
            regen_ids_for_parent.append(seg_id)

            if seg_id not in requested_ids:
                continue

            row = tx_map[seg_id]
            # ascontiguousarray + copy: the tensor gets its own contiguous
            # buffer, never aliasing the polish output array.
            result.segments.append(SegmentData(
                segment_file=seg_id,
                waveform=torch.from_numpy(np.ascontiguousarray(seg.audio).copy()),
                duration_s=float(seg.trim_meta.final_duration_ms) / 1000.0,
                gemini_transcription=row.get("transcription", "") or "",
                gemini_tagged=row.get("tagged", "") or "",
                gemini_lang=row.get("detected_language", "") or row.get("expected_language_hint", "") or "",
                gemini_quality_score=float(row.get("quality_score") or 0.0),
                speaker_info=_speaker_info_from_row(row),
            ))
            matched.add(seg_id)

        # Replay-only children: regenerated now, but absent from history.
        extras.update(set(regen_ids_for_parent) - parent_tx_ids)

    result.matched_tx_ids = sorted(matched)
    result.missing_tx_ids = sorted((requested_ids - matched) | unknown_target_ids)
    result.extra_regen_ids = sorted(extras)
    return result


def _speaker_info_from_row(row: dict) -> dict:
    speaker = {
        "emotion": row.get("speaker_emotion", "") or "",
        "speaking_style": row.get("speaker_style", "") or "",
        "pace": row.get("speaker_pace", "") or "",
        "accent": row.get("speaker_accent", "") or "",
    }
    return {k: v for k, v in speaker.items() if v}


def _find_file(root: Path, name: str) -> Optional[Path]:
    for path in root.rglob(name):
        if path.is_file():
            return path
    return None


def _find_dir(root: Path, name: str) -> Optional[Path]:
    for path in root.rglob(name):
        if path.is_dir():
            return path
    return None
