"""
Audio polishing pipeline: STEP 1 length-split -> STEP 2 boundary-trim -> STEP 3 silence-pad -> STEP 4 encode.
All trim metadata is tracked for downstream use.
"""
from __future__ import annotations

import base64
import io
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import numpy as np
import soundfile as sf

from .config import (
    MIN_SEGMENT_DURATION_S, MAX_SEGMENT_DURATION_S, PREFERRED_MAX_DURATION_S,
    SPLIT_SEARCH_START_S, FORCE_CUT_RANGE, BOUNDARY_CHECK_MS,
    SILENCE_PAD_MS, BOUNDARY_TRIM_MAX_PCT,
)

# Module-level logger for pipeline diagnostics (per-file discard and failure messages).
logger = logging.getLogger(__name__)


@dataclass
class TrimMetadata:
    """Per-segment record of how a source file was split, trimmed and padded.

    All times are in milliseconds. The pipeline steps in this module fill the
    fields in progressively; a segment discarded part-way keeps the defaults
    for the steps it never reached.
    """
    # Name of the source audio file this segment came from.
    original_file: str
    # Duration of the whole source file (not just this segment); 0 when the
    # file could not be processed at all.
    original_duration_ms: float
    # Position of this segment within the source file (set by step 1).
    original_start_ms: float = 0.0
    original_end_ms: float = 0.0
    # Amount cut from the segment's start / end by step 2 (0.0 if untrimmed).
    trimmed_start_ms: float = 0.0
    trimmed_end_ms: float = 0.0
    # Silence added before / after the segment by step 3.
    # NOTE(review): these default to SILENCE_PAD_MS even for segments that are
    # discarded before step 3 runs, so they overstate padding on such entries.
    leading_pad_ms: float = SILENCE_PAD_MS
    trailing_pad_ms: float = SILENCE_PAD_MS
    # Length of the final padded audio; stays 0.0 unless step 3 ran.
    final_duration_ms: float = 0.0
    # Whether / at which position this segment came from splitting a longer file (step 1).
    was_split: bool = False
    split_index: int = 0
    # Set by step 2 when an edge was loud but no silence valley was found to cut at.
    abrupt_start: bool = False
    abrupt_end: bool = False
    # Set by whichever step rejects the segment; `discard_reason` says why.
    discarded: bool = False
    discard_reason: str = ""


@dataclass
class PolishedSegment:
    """One output segment of the polishing pipeline: its audio plus metadata."""
    # Samples as they were when the segment finished (or was discarded from) the pipeline.
    audio: np.ndarray
    # Sample rate of `audio`, in Hz.
    sr: int
    # Split/trim/pad/discard bookkeeping for this segment.
    trim_meta: TrimMetadata
    # FLAC encoding of `audio` and its base64 text form; left empty unless
    # step 4 encoded the segment (i.e. it was not discarded).
    flac_bytes: bytes = b""
    base64_audio: str = ""


def _compute_rms_profile(audio: np.ndarray, sr: int, frame_ms: int = 10) -> np.ndarray:
    """Compute a per-frame RMS energy profile of a 1-D signal.

    The signal is cut into consecutive, non-overlapping frames of
    ``int(sr * frame_ms / 1000)`` samples. Trailing samples that do not fill a
    whole frame are ignored. A signal shorter than one frame yields a single
    RMS value for the whole signal, and an empty signal yields ``[0.0]``.

    Args:
        audio: 1-D sample array (callers in this module pass mono audio).
        sr: Sample rate in Hz.
        frame_ms: Frame length in milliseconds.

    Returns:
        1-D array with one RMS value per frame; never empty.
    """
    if audio.size == 0:
        # np.mean of an empty array is NaN (plus a RuntimeWarning); report silence instead.
        return np.array([0.0])

    frame_len = int(sr * frame_ms / 1000)
    # frame_len can only be 0 for absurdly low sample rates; fall back to one
    # whole-signal value rather than dividing by zero.
    n_frames = len(audio) // frame_len if frame_len > 0 else 0
    if n_frames == 0:
        return np.array([np.sqrt(np.mean(audio ** 2))])

    frames = audio[:n_frames * frame_len].reshape(n_frames, frame_len)
    return np.sqrt(np.mean(frames ** 2, axis=1))


def _find_silence_valleys(rms: np.ndarray, frame_ms: int = 10,
                          threshold_percentile: float = 15.0) -> list[int]:
    """Return indices of RMS frames quieter than a percentile of the profile.

    A frame is a "valley" when its RMS is strictly below the
    ``threshold_percentile``-th percentile of *rms*. When that percentile lands
    exactly on the profile minimum (e.g. a long run of digital silence with
    every frame at 0.0), nothing is strictly below it. In that case the frames
    equal to the minimum are returned instead, unless the profile is completely
    flat, where no frame is quieter than any other and the result is empty.

    Args:
        rms: 1-D RMS profile, one value per frame.
        frame_ms: Unused; kept for call-site compatibility.
        threshold_percentile: Percentile (0-100) used as the silence threshold.

    Returns:
        Ascending list of frame indices (empty for an empty profile).
    """
    rms = np.asarray(rms)
    if rms.size == 0:
        # np.percentile cannot handle an empty profile.
        return []

    threshold = np.percentile(rms, threshold_percentile)
    is_valley = rms < threshold
    if not is_valley.any() and threshold < rms.max():
        # Degenerate case: the threshold equals the minimum, so the quietest
        # frames are exactly those *at* the threshold.
        is_valley = rms <= threshold
    return np.where(is_valley)[0].tolist()


def _frame_to_sample(frame_idx: int, sr: int, frame_ms: int = 10) -> int:
    """Return the sample offset at which RMS frame *frame_idx* begins.

    Uses the same truncated frame length as ``_compute_rms_profile``:
    ``int(sr * frame_ms / 1000)`` samples per frame.
    """
    samples_per_frame = int(sr * frame_ms / 1000)
    return samples_per_frame * frame_idx


def _sample_to_ms(sample: int, sr: int) -> float:
    """Convert a sample offset at rate *sr* into milliseconds."""
    seconds = sample / sr
    return seconds * 1000


def _choose_cut_frame(rms: np.ndarray, valleys: np.ndarray,
                      search_start_frame: int, search_end_frame: int) -> int:
    """Pick the RMS frame to cut at, within [search_start_frame, search_end_frame].

    Preference order: the earliest silence valley in the window; otherwise the
    quietest frame in ``rms[search_start_frame:search_end_frame]``; otherwise
    the window end.
    """
    # `valleys` is sorted ascending (np.where order), so a binary search finds
    # the earliest valley >= search_start_frame without rescanning the list.
    pos = int(np.searchsorted(valleys, search_start_frame, side="left"))
    if pos < len(valleys) and valleys[pos] <= search_end_frame:
        return int(valleys[pos])

    # No valley in range: force-cut at the lowest-energy frame in the window.
    if search_start_frame < search_end_frame:
        force_range = rms[search_start_frame:search_end_frame]
        if len(force_range) > 0:
            return search_start_frame + int(np.argmin(force_range))

    return search_end_frame


def step1_length_split(audio: np.ndarray, sr: int,
                        original_file: str) -> list[tuple[np.ndarray, TrimMetadata]]:
    """Split over-long audio at silence valleys; mark too-short audio as discarded.

    Behavior by input duration:

    * Shorter than ``MIN_SEGMENT_DURATION_S``: returned whole, marked discarded.
    * Up to ``PREFERRED_MAX_DURATION_S``: returned whole, unsplit.
    * Longer: cut repeatedly on 10 ms RMS frames. Each cut is searched for
      between ``SPLIT_SEARCH_START_S`` and ``MAX_SEGMENT_DURATION_S`` after the
      current piece start. The cut goes at the first silence valley in that
      window, else at the quietest frame in it, else at the window end. Once
      the remainder fits within ``MAX_SEGMENT_DURATION_S`` it becomes the last
      piece. Pieces shorter than ``MIN_SEGMENT_DURATION_S`` are dropped and
      logged at debug level.

    Args:
        audio: 1-D mono samples of the whole source file.
        sr: Sample rate in Hz.
        original_file: Source file name recorded in every ``TrimMetadata``.

    Returns:
        ``(samples, metadata)`` pairs in time order. If splitting keeps no
        piece, a single discarded entry holding the whole input is returned.
    """
    total_samples = len(audio)
    duration_s = total_samples / sr
    duration_ms = duration_s * 1000

    if duration_s < MIN_SEGMENT_DURATION_S:
        meta = TrimMetadata(
            original_file=original_file,
            original_duration_ms=duration_ms,
            discarded=True,
            discard_reason=f"Too short: {duration_s:.1f}s < {MIN_SEGMENT_DURATION_S}s",
        )
        return [(audio, meta)]

    if duration_s <= PREFERRED_MAX_DURATION_S:
        meta = TrimMetadata(
            original_file=original_file,
            original_duration_ms=duration_ms,
            original_end_ms=duration_ms,
        )
        return [(audio, meta)]

    # Need to split.
    frame_ms = 10
    # Use the exact frame length that _compute_rms_profile / _frame_to_sample
    # use, so frame <-> sample conversions are integer arithmetic. Converting
    # through milliseconds with a nominal 10 ms frame drifts whenever
    # sr * frame_ms / 1000 is not an integer (e.g. 22050 Hz -> 220-sample
    # frames), which can even move the search window behind the current start.
    frame_len = int(sr * frame_ms / 1000)
    rms = _compute_rms_profile(audio, sr, frame_ms=frame_ms)
    valleys = np.asarray(_find_silence_valleys(rms, frame_ms=frame_ms), dtype=np.int64)
    search_start_offset = int(SPLIT_SEARCH_START_S * sr)
    search_end_offset = int(MAX_SEGMENT_DURATION_S * sr)

    results: list[tuple[np.ndarray, TrimMetadata]] = []
    current_start_sample = 0
    split_idx = 0

    while current_start_sample < total_samples:
        remaining_s = (total_samples - current_start_sample) / sr

        if remaining_s <= MAX_SEGMENT_DURATION_S:
            # The rest fits in one segment: emit it (if long enough) and stop.
            if remaining_s >= MIN_SEGMENT_DURATION_S:
                results.append((audio[current_start_sample:], TrimMetadata(
                    original_file=original_file,
                    original_duration_ms=duration_ms,
                    original_start_ms=_sample_to_ms(current_start_sample, sr),
                    original_end_ms=_sample_to_ms(total_samples, sr),
                    # True whenever this piece does not start at the beginning
                    # of the file, even if every earlier piece was dropped.
                    was_split=current_start_sample > 0,
                    split_index=split_idx,
                )))
            else:
                logger.debug("Dropping %.2fs tail of %s (< %ss minimum)",
                             remaining_s, original_file, MIN_SEGMENT_DURATION_S)
            break

        search_start_frame = (current_start_sample + search_start_offset) // frame_len
        search_end_frame = min((current_start_sample + search_end_offset) // frame_len, len(rms))
        cut_frame = _choose_cut_frame(rms, valleys, search_start_frame, search_end_frame)
        cut_sample = min(_frame_to_sample(cut_frame, sr, frame_ms=frame_ms), total_samples)
        if cut_sample <= current_start_sample:
            # Guarantee forward progress (possible only if SPLIT_SEARCH_START_S
            # is so small that the window begins at the current frame).
            cut_sample = min(current_start_sample + search_end_offset, total_samples)

        chunk = audio[current_start_sample:cut_sample]
        chunk_dur = len(chunk) / sr

        if chunk_dur >= MIN_SEGMENT_DURATION_S:
            results.append((chunk, TrimMetadata(
                original_file=original_file,
                original_duration_ms=duration_ms,
                original_start_ms=_sample_to_ms(current_start_sample, sr),
                original_end_ms=_sample_to_ms(cut_sample, sr),
                was_split=True,
                split_index=split_idx,
            )))
            split_idx += 1
        else:
            logger.debug("Dropping %.2fs piece of %s (< %ss minimum)",
                         chunk_dur, original_file, MIN_SEGMENT_DURATION_S)

        current_start_sample = cut_sample

    if results:
        return results
    return [(audio, TrimMetadata(
        original_file=original_file,
        original_duration_ms=duration_ms,
        discarded=True,
        discard_reason="No valid split points found",
    ))]


def step2_boundary_trim(audio: np.ndarray, sr: int,
                         meta: TrimMetadata) -> tuple[np.ndarray, TrimMetadata]:
    """Trim "dirty" (non-silent) segment edges back to a nearby silence valley.

    The RMS of the first and last ``BOUNDARY_CHECK_MS`` is compared with 60% of
    the median 5 ms-frame RMS of the segment. For each edge that is louder, the
    outermost ``BOUNDARY_TRIM_MAX_PCT`` of the segment (in frames) is searched
    for silence valleys, using a 25th-percentile threshold within that window:

    * start edge: everything before the first valley in the window is cut;
    * end edge: everything from the last valley in the window onward is cut.

    If no valley is found, the edge is kept and flagged via
    ``meta.abrupt_start`` / ``meta.abrupt_end``. Segments of at most
    ``2 * BOUNDARY_CHECK_MS`` are never trimmed. ``meta`` is updated in place
    (trim amounts in ms, abrupt flags, discard status) and also returned.

    Args:
        audio: 1-D mono samples of one segment.
        sr: Sample rate in Hz.
        meta: Metadata for this segment; already-discarded segments pass through.

    Returns:
        ``(trimmed_audio, meta)``. If the trimmed audio is shorter than
        ``MIN_SEGMENT_DURATION_S``, the segment is marked discarded.
    """
    if meta.discarded:
        return audio, meta

    frame_ms = 5
    frame_len = int(sr * frame_ms / 1000)
    check_samples = int(sr * BOUNDARY_CHECK_MS / 1000)
    max_trim_samples = int(len(audio) * BOUNDARY_TRIM_MAX_PCT)
    # The RMS profile is indexed by frame, not by sample, so the trim budget
    # must be converted before slicing it.
    max_trim_frames = max_trim_samples // frame_len if frame_len > 0 else 0

    rms = _compute_rms_profile(audio, sr, frame_ms=frame_ms)
    median_rms = np.median(rms) if len(rms) > 0 else 0
    # Threshold: if edge RMS > 60% of median, consider it "dirty"
    dirty_threshold = median_rms * 0.6
    # Very short segments are left untouched.
    can_trim = len(audio) > check_samples * 2

    trim_start = 0
    trim_end = len(audio)

    # NOTE(review): the valley threshold is a percentile of the search window
    # itself, so a valley is usually found in any window whose energy varies;
    # the abrupt_* flags therefore mostly mark windows of near-constant energy.

    # Check start
    start_chunk = audio[:min(check_samples, len(audio))]
    start_rms = np.sqrt(np.mean(start_chunk ** 2)) if len(start_chunk) > 0 else 0
    if can_trim and start_rms > dirty_threshold:
        search_limit = min(max_trim_frames, len(rms))
        valleys = (_find_silence_valleys(rms[:search_limit], frame_ms=frame_ms,
                                         threshold_percentile=25)
                   if search_limit > 0 else [])
        if valleys:
            trim_start = _frame_to_sample(valleys[0], sr, frame_ms=frame_ms)
            meta.trimmed_start_ms = _sample_to_ms(trim_start, sr)
        else:
            meta.abrupt_start = True

    # Check end
    end_chunk = audio[max(0, len(audio) - check_samples):]
    end_rms = np.sqrt(np.mean(end_chunk ** 2)) if len(end_chunk) > 0 else 0
    if can_trim and end_rms > dirty_threshold:
        search_start = max(0, len(rms) - max_trim_frames)
        window = rms[search_start:]
        valleys_end = (_find_silence_valleys(window, frame_ms=frame_ms,
                                             threshold_percentile=25)
                       if len(window) > 0 else [])
        if valleys_end:
            trim_end = _frame_to_sample(search_start + valleys_end[-1], sr, frame_ms=frame_ms)
            meta.trimmed_end_ms = _sample_to_ms(len(audio) - trim_end, sr)
        else:
            meta.abrupt_end = True

    # If the two trims ever crossed, the slice is empty and the length check
    # below discards the segment.
    trimmed = audio[trim_start:trim_end]

    # Post-trim length check
    if len(trimmed) / sr < MIN_SEGMENT_DURATION_S:
        meta.discarded = True
        meta.discard_reason = f"Post-trim too short: {len(trimmed)/sr:.1f}s"

    return trimmed, meta


def step3_silence_pad(audio: np.ndarray, sr: int,
                       meta: TrimMetadata) -> tuple[np.ndarray, TrimMetadata]:
    """Add ``SILENCE_PAD_MS`` of digital silence (zeros) before and after *audio*.

    Discarded segments pass through untouched. Otherwise ``meta`` is updated
    in place with the pad lengths and the final padded duration. It is returned
    alongside the padded samples, which keep the input dtype.
    """
    if meta.discarded:
        return audio, meta

    n_pad = int(sr * SILENCE_PAD_MS / 1000)
    quiet = np.zeros(n_pad, dtype=audio.dtype)
    padded = np.concatenate((quiet, audio, quiet))

    meta.leading_pad_ms = SILENCE_PAD_MS
    meta.trailing_pad_ms = SILENCE_PAD_MS
    padded_s = len(padded) / sr
    meta.final_duration_ms = padded_s * 1000

    # Padding only lengthens the audio, so this fires only when the input was
    # already well under the minimum (step 2 discards such segments first).
    if padded_s < MIN_SEGMENT_DURATION_S:
        meta.discarded = True
        meta.discard_reason = f"Post-pad too short: {len(padded)/sr:.1f}s"

    return padded, meta


def step4_encode(audio: np.ndarray, sr: int) -> tuple[bytes, str]:
    """Serialize *audio* as an in-memory FLAC file.

    Returns:
        ``(flac_bytes, base64_text)``, where the text is the ASCII base64
        encoding of the same bytes.
    """
    with io.BytesIO() as sink:
        sf.write(sink, audio, sr, format="FLAC")
        encoded = sink.getvalue()
    return encoded, base64.b64encode(encoded).decode("ascii")


def polish_segment(audio_path: Path) -> list[PolishedSegment]:
    """Load one audio file and push it through split -> trim -> pad -> encode.

    Multi-channel input is averaged down to mono before step 1. Step 1 may
    yield several chunks, and each one is processed independently. A chunk
    rejected at any step is still returned (``trim_meta.discarded`` set, no
    encoded bytes) so its metadata survives.

    Returns:
        One ``PolishedSegment`` per chunk produced by step 1, in the same order.
    """
    samples, sr = sf.read(str(audio_path), dtype="float32")
    if samples.ndim > 1:
        samples = samples.mean(axis=1)  # mono downmix: average the channels

    name = audio_path.name
    out: list[PolishedSegment] = []

    def emit(seg_audio, seg_meta, **encoded):
        out.append(PolishedSegment(audio=seg_audio, sr=sr, trim_meta=seg_meta, **encoded))

    for chunk, meta in step1_length_split(samples, sr, original_file=name):
        if meta.discarded:
            logger.debug(f"Discarded {name}: {meta.discard_reason}")
            emit(chunk, meta)
            continue

        chunk, meta = step2_boundary_trim(chunk, sr, meta)
        if meta.discarded:
            logger.debug(f"Discarded after trim {name}: {meta.discard_reason}")
            emit(chunk, meta)
            continue

        chunk, meta = step3_silence_pad(chunk, sr, meta)
        if meta.discarded:
            emit(chunk, meta)
            continue

        flac_bytes, b64 = step4_encode(chunk, sr)
        emit(chunk, meta, flac_bytes=flac_bytes, base64_audio=b64)

    return out


def _safe_polish_segment(path: Path) -> list[PolishedSegment]:
    """Run ``polish_segment`` on *path* without letting an error escape.

    Used as the thread-pool worker in ``polish_all_segments`` so that one bad
    file cannot abort the batch. On failure, the error is logged together with
    its traceback and a single discarded ``PolishedSegment`` is returned. That
    entry has empty float32 audio, a placeholder ``sr`` of 16000, and the
    exception message in ``discard_reason``.
    """
    try:
        return polish_segment(path)
    except Exception as e:  # deliberate catch-all at the worker boundary
        # logger.exception logs at ERROR level like logger.error, but also
        # records the traceback, which would otherwise be lost here.
        logger.exception("Failed to polish %s: %s", path.name, e)
        return [PolishedSegment(
            # float32 to match the dtype polish_segment loads audio with.
            audio=np.array([], dtype=np.float32), sr=16000,
            trim_meta=TrimMetadata(
                original_file=path.name,
                original_duration_ms=0,
                discarded=True,
                discard_reason=f"Polish error: {e}",
            ),
        )]


def polish_all_segments(segment_paths: list[Path]) -> list[PolishedSegment]:
    """Polish every path in *segment_paths* on a thread pool and flatten the output.

    The pool size is the smallest of the CPU count (4 if unknown), the number
    of paths, and 16. Results keep the input order. Each path contributes the
    list returned by ``_safe_polish_segment``, so a failing file shows up as a
    discarded entry instead of raising. Threads are used on the premise that
    soundfile/numpy release the GIL during their heavy work.
    """
    import os
    from concurrent.futures import ThreadPoolExecutor

    if not segment_paths:
        return []

    pool_size = min(os.cpu_count() or 4, len(segment_paths), 16)
    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        per_path = executor.map(_safe_polish_segment, segment_paths)
        flattened = [seg for batch in per_path for seg in batch]
    return flattened
