"""
Audio loading, resampling, and batch collation.
Loads FLAC segments from extracted tars, resamples to 16kHz,
and provides padded batches for model inference.
"""
from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torchaudio

from .config import AUDIO_SAMPLE_RATE, MAX_AUDIO_DURATION_S, MIN_AUDIO_DURATION_S

logger = logging.getLogger(__name__)


@dataclass
class SegmentData:
    """One audio segment with its metadata and transcription.

    Produced by load_video_segments(): the waveform has already been
    downmixed to mono and resampled to AUDIO_SAMPLE_RATE.
    """
    # Basename of the source FLAC file (e.g. "0001.flac").
    segment_file: str
    waveform: torch.Tensor          # [samples] mono float32
    duration_s: float               # length in seconds at sample_rate
    sample_rate: int = AUDIO_SAMPLE_RATE
    # From transcription JSON; all defaults apply when no transcription
    # file with a matching stem was found for this segment.
    gemini_transcription: str = ""
    gemini_tagged: str = ""
    # Detected language; loader falls back to the video-level metadata
    # "language" field when the per-segment JSON lacks it.
    gemini_lang: str = ""
    gemini_quality_score: float = 0.0
    speaker_info: dict = field(default_factory=dict)


def load_video_segments(
    work_dir: Path,
    video_id: str,
) -> tuple[dict, list[SegmentData]]:
    """
    Load all segments from an extracted transcribed tar.

    Pairs each FLAC under segments/ with the transcription JSON of the
    same stem under transcriptions/, converts audio to mono, resamples
    to AUDIO_SAMPLE_RATE, and drops segments whose duration falls
    outside [MIN_AUDIO_DURATION_S, MAX_AUDIO_DURATION_S].

    Args:
        work_dir: Directory the tar was extracted into.
        video_id: Expected subdirectory name; if work_dir/{video_id}
            does not exist, files are assumed to sit directly in
            work_dir.

    Returns:
        (metadata_dict, list_of_segments). metadata_dict is {} when
        metadata.json is missing or unparseable; the segment list is []
        when no usable audio is found.

    Expected structure:
      work_dir/{video_id}/metadata.json
      work_dir/{video_id}/segments/*.flac
      work_dir/{video_id}/transcriptions/*.json
    """
    video_dir = work_dir / video_id
    if not video_dir.exists():
        # Maybe files are directly in work_dir
        video_dir = work_dir

    # Fix: metadata parsing was unguarded while transcription JSONs were
    # guarded — one corrupt metadata.json aborted the whole video. Treat
    # it the same way: log and continue with empty metadata.
    metadata: dict = {}
    metadata_path = _find_file(video_dir, "metadata.json")
    if metadata_path:
        try:
            metadata = json.loads(metadata_path.read_text())
        except Exception as e:
            logger.warning("Bad metadata JSON %s: %s", metadata_path.name, e)

    segments_dir = _find_dir(video_dir, "segments")
    transcriptions_dir = _find_dir(video_dir, "transcriptions")

    if not segments_dir:
        logger.warning("No segments/ dir found in %s", video_dir)
        return metadata, []

    flac_paths = sorted(segments_dir.glob("*.flac"))
    if not flac_paths:
        logger.warning("No FLAC files in %s", segments_dir)
        return metadata, []

    # Build transcription lookup: segment_stem -> transcription_dict
    transcription_map: dict[str, dict] = {}
    if transcriptions_dir:
        for json_path in transcriptions_dir.glob("*.json"):
            try:
                transcription_map[json_path.stem] = json.loads(json_path.read_text())
            except Exception as e:
                logger.warning("Bad transcription JSON %s: %s", json_path.name, e)

    # Resample transforms are reusable modules; cache one per source rate
    # so each kernel is constructed once, not per segment.
    resampler_cache: dict[int, torchaudio.transforms.Resample] = {}
    segments: list[SegmentData] = []
    skipped = 0

    for flac_path in flac_paths:
        try:
            waveform, sr = torchaudio.load(str(flac_path))

            # Downmix multi-channel audio to mono by averaging channels.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            waveform = waveform.squeeze(0)  # [samples]

            # Resample to the target rate if needed.
            if sr != AUDIO_SAMPLE_RATE:
                if sr not in resampler_cache:
                    resampler_cache[sr] = torchaudio.transforms.Resample(sr, AUDIO_SAMPLE_RATE)
                waveform = resampler_cache[sr](waveform)

            # Duration is computed post-resample, so AUDIO_SAMPLE_RATE is
            # the correct divisor for both branches above.
            duration_s = waveform.shape[0] / AUDIO_SAMPLE_RATE

            if duration_s < MIN_AUDIO_DURATION_S or duration_s > MAX_AUDIO_DURATION_S:
                skipped += 1
                continue

            # Transcriptions are keyed by the FLAC stem; a missing entry
            # leaves the SegmentData text fields at their empty defaults.
            tx = transcription_map.get(flac_path.stem, {})

            segments.append(SegmentData(
                segment_file=flac_path.name,
                waveform=waveform,
                duration_s=duration_s,
                gemini_transcription=tx.get("transcription", ""),
                gemini_tagged=tx.get("tagged", ""),
                gemini_lang=tx.get("detected_language", metadata.get("language", "")),
                gemini_quality_score=tx.get("quality_score", 0.0),
                speaker_info=tx.get("speaker", {}),
            ))
        except Exception as e:
            logger.warning("Failed to load %s: %s", flac_path.name, e)
            skipped += 1

    if skipped:
        logger.info("Loaded %d segments, skipped %d (too short/long/corrupt)", len(segments), skipped)
    else:
        logger.info("Loaded %d segments", len(segments))

    return metadata, segments


def collate_waveforms(
    waveforms: list[torch.Tensor],
    target_sr: int = AUDIO_SAMPLE_RATE,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pad variable-length waveforms into a batch tensor.

    Args:
        waveforms: List of 1-D float tensors [samples]. May be empty.
        target_sr: Unused; kept for interface compatibility. Waveforms
            are expected to already be at this rate — no resampling
            happens here.

    Returns:
        (padded_batch [B, max_len] float32, right-padded with zeros,
         lengths [B] long holding each original sample count).
        For an empty input, returns ([0, 0] and [0] shaped tensors).
    """
    # Robustness fix: max() on an empty tensor raises, so handle the
    # empty batch explicitly instead of crashing.
    if not waveforms:
        return torch.zeros(0, 0, dtype=torch.float32), torch.zeros(0, dtype=torch.long)

    lengths = torch.tensor([w.shape[0] for w in waveforms], dtype=torch.long)
    max_len = int(lengths.max().item())

    batch = torch.zeros(len(waveforms), max_len, dtype=torch.float32)
    for i, w in enumerate(waveforms):
        batch[i, :w.shape[0]] = w

    return batch, lengths


def batch_segments(
    segments: list[SegmentData],
    batch_size: int,
) -> list[list[SegmentData]]:
    """Chunk segments into consecutive batches of at most batch_size each."""
    batches: list[list[SegmentData]] = []
    for start in range(0, len(segments), batch_size):
        batches.append(segments[start:start + batch_size])
    return batches


def _find_file(root: Path, name: str) -> Optional[Path]:
    """Return the first file named *name* anywhere under *root*, else None."""
    return next((hit for hit in root.rglob(name) if hit.is_file()), None)


def _find_dir(root: Path, name: str) -> Optional[Path]:
    """Return the first directory named *name* anywhere under *root*, else None."""
    return next((hit for hit in root.rglob(name) if hit.is_dir()), None)
