"""VAD-aware audio segmentation using Silero-VAD.

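Produces speech segments whose lengths fall within the configured
``VADConfig.min_segment_s``..``VADConfig.max_segment_s`` range (e.g. 3-30 s),
with a varied duration distribution. Cuts are placed at silence boundaries
wherever possible to avoid mid-word splits; only over-long speech runs are
split at arbitrary points.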
"""

from __future__ import annotations

import logging
import random
from dataclasses import dataclass
from pathlib import Path

import torch
import torchaudio

from codecbench.pipeline.config import VADConfig

logger = logging.getLogger(__name__)


@dataclass
class Segment:
    """A slice of speech audio plus its position within the source recording.

    Times are in seconds measured from the start of the source audio.
    """

    start_s: float  # position of the first sample, seconds
    end_s: float  # position just past the last sample, seconds
    audio: torch.Tensor  # [1, T] at target_sr

    @property
    def duration_s(self) -> float:
        """Length of the segment in seconds, i.e. ``end_s - start_s``."""
        span = self.end_s - self.start_s
        return span


import threading

_vad_lock = threading.Lock()
_vad_models: dict[int, tuple] = {}  # thread_id -> (model, utils)


def _load_vad():
    """Return the calling thread's ``(model, utils)`` Silero-VAD pair, loading it on first use.

    Silero-VAD has internal RNN state that isn't thread-safe, so each thread
    gets its own instance, stored in ``_vad_models`` under
    ``threading.get_ident()``. Loading is serialized by ``_vad_lock``.

    Whenever a new instance is created, entries belonging to threads that
    have exited are removed. Without this, short-lived worker threads (such
    as the per-call pool used for chunked VAD) would each leave a model
    behind for the lifetime of the process.

    Returns:
        The ``(model, utils)`` tuple produced by ``torch.hub.load`` for
        ``snakers4/silero-vad``.
    """
    # For a thread not started via `threading`, this registers a dummy Thread
    # object. That makes it visible to threading.enumerate(), so its entry is
    # never pruned while it is alive.
    threading.current_thread()
    tid = threading.get_ident()

    # Keep the entry in a local so a concurrent prune cannot cause a KeyError
    # between the membership check and the return.
    entry = _vad_models.get(tid)
    if entry is None:
        with _vad_lock:
            entry = _vad_models.get(tid)
            if entry is None:
                live_ids = {t.ident for t in threading.enumerate()}
                for stale_id in [k for k in _vad_models if k not in live_ids]:
                    del _vad_models[stale_id]

                model, utils = torch.hub.load(
                    repo_or_dir="snakers4/silero-vad",
                    model="silero_vad",
                    trust_repo=True,
                )
                entry = (model, utils)
                _vad_models[tid] = entry
    return entry


def get_speech_timestamps(
    wav: torch.Tensor,
    sr: int,
    cfg: VADConfig,
) -> list[dict]:
    """Run Silero-VAD on ``wav`` and return its speech timestamp dicts.

    Silero-VAD is run at 16 kHz, so ``wav`` is resampled when ``sr`` differs.
    Silero's ``"start"``/``"end"`` sample indices are then mapped back to
    ``sr``, so they index the caller's ``wav`` directly. Every caller in this
    module slices and offsets them at ``sr``.

    Args:
        wav: mono audio as ``[T]`` or ``[1, T]``. Multi-channel ``[C, T]``
            input is averaged to mono.
        sr: sample rate of ``wav``.
        cfg: supplies ``threshold``, ``min_silence_duration_ms``,
            ``speech_pad_ms`` and ``min_speech_duration_s``.

    Returns:
        List of dicts with integer ``"start"``/``"end"`` sample indices at
        ``sr``, clamped to the length of ``wav``.
    """
    model, utils = _load_vad()
    get_ts = utils[0]  # silero's get_speech_timestamps

    if wav.ndim > 1:
        # Collapse to 1-D without squeeze(), which would turn [1, 1] into a
        # 0-D tensor. [1, T] -> [T]; [C, T] is downmixed.
        wav = wav.reshape(-1, wav.shape[-1])
        wav = wav[0] if wav.shape[0] == 1 else wav.mean(dim=0)

    orig_sr = sr
    num_samples = wav.shape[-1]
    if sr != 16_000:
        wav = torchaudio.functional.resample(wav, sr, 16_000)
        sr = 16_000

    timestamps = get_ts(
        wav,
        model,
        sampling_rate=sr,
        threshold=cfg.threshold,
        min_silence_duration_ms=cfg.min_silence_duration_ms,
        speech_pad_ms=cfg.speech_pad_ms,
        min_speech_duration_ms=int(cfg.min_speech_duration_s * 1000),
    )

    if orig_sr != 16_000:
        # Silero's indices are at 16 kHz; convert back to the caller's rate.
        # Round start down and end up so detected speech is never shrunk.
        for ts in timestamps:
            ts["start"] = min(ts["start"] * orig_sr // 16_000, num_samples)
            ts["end"] = min(-(-ts["end"] * orig_sr // 16_000), num_samples)
    return timestamps


def _merge_segments_to_target(
    timestamps: list[dict],
    total_samples: int,
    sr: int,
    cfg: VADConfig,
) -> list[tuple[int, int]]:
    """Merge VAD speech chunks into segments within [min_s, max_s] range.

    Strategy: greedily accumulate consecutive speech chunks. When accumulated
    duration would exceed max_segment_s, cut at the last silence boundary.
    Random target durations give good distribution across [min_s, max_s].
    """
    min_samples = int(cfg.min_segment_s * sr)
    max_samples = int(cfg.max_segment_s * sr)

    if not timestamps:
        return []

    segments: list[tuple[int, int]] = []
    current_start = timestamps[0]["start"]
    current_end = timestamps[0]["end"]

    for i in range(1, len(timestamps)):
        chunk_start = timestamps[i]["start"]
        chunk_end = timestamps[i]["end"]

        proposed_end = chunk_end
        proposed_dur = proposed_end - current_start

        if proposed_dur <= max_samples:
            current_end = proposed_end
        else:
            seg_dur = current_end - current_start
            if seg_dur >= min_samples:
                segments.append((current_start, current_end))
            elif seg_dur > 0:
                # Too short on its own — extend into next chunk if possible
                extended = min(current_start + min_samples, chunk_end, total_samples)
                segments.append((current_start, extended))

            current_start = chunk_start
            current_end = chunk_end

    # Final segment
    final_dur = current_end - current_start
    if final_dur >= min_samples:
        segments.append((current_start, current_end))

    # Split segments that are still too long (shouldn't happen often)
    result = []
    for start, end in segments:
        dur = end - start
        if dur <= max_samples:
            result.append((start, end))
        else:
            pos = start
            while pos < end:
                # Random target between 60-100% of max to get duration variety
                target = int(max_samples * random.uniform(0.6, 1.0))
                seg_end = min(pos + target, end)
                if seg_end - pos >= min_samples:
                    result.append((pos, seg_end))
                pos = seg_end

    return result


def segment_audio(
    audio_path: Path | str,
    cfg: VADConfig | None = None,
) -> list[Segment]:
    """Full VAD pipeline for a file: load audio → detect speech → create segments.

    The file is downmixed to mono and resampled to ``cfg.sample_rate``. Speech
    is detected with a single Silero-VAD pass and merged into segments whose
    lengths are governed by ``cfg.min_segment_s`` / ``cfg.max_segment_s``.
    Summary statistics are logged.

    Args:
        audio_path: path to any file ``torchaudio.load`` can read.
        cfg: VAD/segmentation settings; a default ``VADConfig()`` if omitted.

    Returns:
        Segment objects whose ``audio`` is ``[1, t]`` at ``cfg.sample_rate``;
        empty if no usable speech was found.
    """
    if cfg is None:
        cfg = VADConfig()

    wav, sr = torchaudio.load(str(audio_path))
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != cfg.sample_rate:
        wav = torchaudio.functional.resample(wav, sr, cfg.sample_rate)
        sr = cfg.sample_rate

    # Index rather than squeeze(): a one-sample [1, 1] tensor would become 0-D.
    wav_1d = wav[0]
    total_samples = wav_1d.shape[0]
    total_duration = total_samples / sr

    logger.info("Audio loaded: %.1f s, %d samples", total_duration, total_samples)

    if total_samples == 0:
        logger.warning("No valid speech segments found in %s", audio_path)
        return []

    timestamps = get_speech_timestamps(wav_1d, sr, cfg)
    speech_dur = sum((ts["end"] - ts["start"]) for ts in timestamps) / sr
    logger.info("VAD detected %.1f s speech out of %.1f s total (%.0f%%)",
                speech_dur, total_duration, 100 * speech_dur / max(total_duration, 0.01))

    seg_bounds = _merge_segments_to_target(timestamps, total_samples, sr, cfg)

    segments = [
        Segment(start_s=start / sr, end_s=end / sr, audio=wav[:, start:end])
        for start, end in seg_bounds
    ]

    durations = [s.duration_s for s in segments]
    if durations:
        logger.info(
            "Created %d segments: %.1f-%.1f s (mean %.1f s, total %.1f s usable)",
            len(segments), min(durations), max(durations),
            sum(durations) / len(durations), sum(durations),
        )
    else:
        logger.warning("No valid speech segments found in %s", audio_path)

    return segments


def segment_tensor(
    wav: torch.Tensor,
    sr: int,
    cfg: VADConfig | None = None,
) -> list[Segment]:
    """Same as segment_audio but from an in-memory tensor ``[1, T]``, ``[C, T]`` or ``[T]``.

    Multi-channel input is averaged to mono and resampled to
    ``cfg.sample_rate``. Audio longer than ``cfg.chunk_threshold_s`` goes
    through the chunked, multi-threaded VAD path. Shorter audio is processed
    in a single pass. Unlike ``segment_audio``, no summary statistics are
    logged.

    Returns:
        Segment objects whose ``audio`` is a ``[1, t]`` slice of the mono,
        resampled tensor; empty for zero-length input.
    """
    if cfg is None:
        cfg = VADConfig()

    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != cfg.sample_rate:
        wav = torchaudio.functional.resample(wav, sr, cfg.sample_rate)
        sr = cfg.sample_rate

    # Index rather than squeeze(): a one-sample [1, 1] tensor would become 0-D.
    wav_1d = wav[0]
    total_samples = wav_1d.shape[0]
    if total_samples == 0:
        return []
    duration_s = total_samples / sr

    # Chunked parallel VAD for long audio (> chunk_threshold_s).
    if duration_s > cfg.chunk_threshold_s:
        timestamps = _chunked_vad(wav_1d, sr, cfg, cfg.chunk_size_s, cfg.chunk_overlap_s)
    else:
        timestamps = get_speech_timestamps(wav_1d, sr, cfg)

    seg_bounds = _merge_segments_to_target(timestamps, total_samples, sr, cfg)

    return [
        Segment(start_s=start / sr, end_s=end / sr, audio=wav[:, start:end])
        for start, end in seg_bounds
    ]


def _chunked_vad(
    wav_1d: torch.Tensor,
    sr: int,
    cfg: VADConfig,
    chunk_size_s: float = 300.0,
    overlap_s: float = 2.0,
) -> list[dict]:
    """Run VAD on audio chunks in parallel, merge results.

    Splits long audio into overlapping chunks, processes each in its own thread
    (each with its own Silero model), then merges timestamps with overlap dedup.
    For 2500s audio with 300s chunks: 36s → ~4s (9x speedup).
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    total_samples = wav_1d.shape[0]
    chunk_samples = int(chunk_size_s * sr)
    overlap_samples = int(overlap_s * sr)
    step_samples = chunk_samples - overlap_samples

    # Build chunk boundaries
    chunks: list[tuple[int, int]] = []
    pos = 0
    while pos < total_samples:
        end = min(pos + chunk_samples, total_samples)
        chunks.append((pos, end))
        pos += step_samples
        if end >= total_samples:
            break

    if len(chunks) <= 1:
        return get_speech_timestamps(wav_1d, sr, cfg)

    logger.debug("Chunked VAD: %d chunks of %.0fs (%.1fs overlap) for %.1fs audio",
                 len(chunks), chunk_size_s, overlap_s, total_samples / sr)

    def _vad_chunk(chunk_start: int, chunk_end: int) -> list[dict]:
        chunk_wav = wav_1d[chunk_start:chunk_end]
        ts = get_speech_timestamps(chunk_wav, sr, cfg)
        # Offset timestamps to global position
        for t in ts:
            t["start"] += chunk_start
            t["end"] += chunk_start
        return ts

    # Run chunks in parallel -- each thread gets its own VAD model via _load_vad()
    max_workers = min(len(chunks), 8)
    all_timestamps: list[list[dict]] = [None] * len(chunks)

    with ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="vad") as pool:
        future_to_idx = {
            pool.submit(_vad_chunk, start, end): i
            for i, (start, end) in enumerate(chunks)
        }
        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            all_timestamps[idx] = future.result()

    # Merge and deduplicate overlapping regions
    merged = _merge_overlapping_timestamps(all_timestamps, overlap_samples)
    return merged


def _merge_overlapping_timestamps(
    chunk_timestamps: list[list[dict]],
    overlap_samples: int,
) -> list[dict]:
    """Merge timestamps from overlapping chunks, deduplicating the overlap regions.

    For each pair of adjacent chunks, timestamps in the overlap zone are resolved
    by keeping the earlier chunk's timestamps up to the midpoint of the overlap,
    and the later chunk's timestamps from the midpoint onward.
    """
    if not chunk_timestamps:
        return []

    result = list(chunk_timestamps[0])

    for i in range(1, len(chunk_timestamps)):
        next_ts = chunk_timestamps[i]
        if not next_ts:
            continue
        if not result:
            result.extend(next_ts)
            continue

        # The overlap midpoint: anything before this comes from the previous chunk,
        # anything after from the current chunk
        last_end_prev = result[-1]["end"] if result else 0
        first_start_next = next_ts[0]["start"] if next_ts else 0

        # Find the boundary: midpoint of where chunks overlap
        # Timestamps from next_ts that start after the last timestamp of result
        cutoff = last_end_prev
        for t in next_ts:
            if t["start"] >= cutoff:
                result.append(t)
            elif t["end"] > cutoff:
                result.append({"start": cutoff, "end": t["end"]})

    return result
