"""
Audio Polisher - Segment boundary cleanup for transcription quality.

Two-phase polishing:
  Phase 1 (original): Remove CUT SPEECH ARTIFACTS from VAD boundaries
    - Detects burst -> gap -> speech pattern and removes the burst
    - Silence at boundaries preserved as natural padding
  Phase 2 (new): TIGHTEN boundaries to actual speech onset/offset
    - Energy-based detection of where speech actually starts/ends
    - Cuts to those points with configurable silence margin (~50ms)
    - Ensures clean, crisp boundaries for TTS training data
    - Prevents segments starting/ending mid-voice-energy

Combined result: artifact-free + precisely-bounded audio segments.
No ML models - pure signal processing, runs in <10ms/segment.
"""
import os
import numpy as np
import soundfile as sf
from dataclasses import dataclass
from typing import Optional, Tuple, List
from pathlib import Path


@dataclass
class PolishResult:
    """Per-segment report describing what polishing did to one audio file.

    Durations are in milliseconds and levels in decibels. ``start_quality``
    and ``end_quality`` carry the label assigned to each boundary.
    """
    input_path: str              # segment that was analysed
    output_path: str             # destination path reported for the segment
    was_modified: bool           # whether polishing changed the audio
    start_trimmed_ms: float = 0.0
    end_trimmed_ms: float = 0.0
    start_quality: str = "clean"
    end_quality: str = "clean"
    original_duration_ms: float = 0.0
    polished_duration_ms: float = 0.0
    snr_db: float = 0.0
    rms_db: float = 0.0
    peak_db: float = 0.0
    is_clipped: bool = False
    volume_adjusted: bool = False
    volume_gain_db: float = 0.0

    @property
    def total_trimmed_ms(self):
        """Milliseconds removed from both ends combined."""
        trimmed = self.start_trimmed_ms
        trimmed += self.end_trimmed_ms
        return trimmed

    def summary(self):
        """Return a one-line, human-readable description of this result."""
        label = Path(self.input_path).name
        snr_text = f"SNR:{self.snr_db:.1f}dB"

        if self.was_modified:
            # Each change is formatted only when it actually happened.
            pieces = [
                *([f"start -{self.start_trimmed_ms:.0f}ms ({self.start_quality})"]
                  if self.start_trimmed_ms > 0 else []),
                *([f"end -{self.end_trimmed_ms:.0f}ms ({self.end_quality})"]
                  if self.end_trimmed_ms > 0 else []),
                *([f"gain +{self.volume_gain_db:.1f}dB"]
                  if self.volume_adjusted else []),
            ]
            details = " | ".join(pieces)
            return (f"  {label}: POLISHED | "
                    f"{self.original_duration_ms:.0f}ms -> "
                    f"{self.polished_duration_ms:.0f}ms | {details} | "
                    f"{snr_text}")

        return (f"  {label}: SKIP | {self.original_duration_ms:.0f}ms | "
                f"{snr_text} | start:{self.start_quality} "
                f"end:{self.end_quality}")


class AudioPolisher:
    """
    Two-phase audio segment polisher for transcription quality.

    Phase 1: Conservative artifact removal (burst-gap-speech pattern detection)
    Phase 2: Energy-based boundary tightening (speech onset/offset + silence margin)

    Analysis runs on short-time frames (``frame_ms`` long, one every
    ``hop_ms``). Internal times are seconds; reported trims are milliseconds.
    """

    # Phase 2 constants shared by onset/offset detection and tightening.
    _SPEECH_MARGIN_DB = 6.0     # dB above noise floor counted as speech (~2x amplitude)
    _SPEECH_CONSEC_FRAMES = 3   # consecutive loud frames needed to accept speech
    _MIN_TIGHTEN_MS = 30.0      # an end must lose more than this to count as tightened

    def __init__(
        self,
        frame_ms=10.0,
        hop_ms=5.0,
        max_start_trim_ms=200.0,
        artifact_search_ms=150.0,
        min_gap_ms=15.0,
        max_end_trim_ms=150.0,
        min_burst_ms=10.0,
        max_burst_ms=100.0,
        energy_gap_factor=0.35,
        zcr_artifact_threshold=0.15,
        min_dynamic_range_db=6.0,
        min_silence_pad_ms=30.0,
        min_remaining_ms=300.0,
        min_trim_ms=10.0,
        fade_in_ms=5.0,
        fade_out_ms=10.0,
        target_rms_db=-18.0,
        snr_boost_threshold_db=12.0,
        max_gain_db=12.0,
        clipping_threshold=0.99,
        # Phase 2: boundary tightening params
        tighten_boundaries=True,
        silence_margin_ms=50.0,
        output_format="flac",
    ):
        """
        Configure the polisher. ``*_ms`` values are milliseconds and ``*_db``
        values are decibels.

        Args:
            frame_ms, hop_ms: Analysis frame length and frame step.
            max_start_trim_ms: Largest leading trim Phase 1 may make.
            artifact_search_ms: Start window scanned for a burst -> gap pattern.
            min_gap_ms: Shortest low-energy run that counts as a gap.
            max_end_trim_ms: Largest trailing trim Phase 1 may make.
            min_burst_ms, max_burst_ms: Allowed length of a trailing burst.
            energy_gap_factor: Fraction of the 15th-85th percentile energy
                range, added to the 15th percentile, used as loud/quiet threshold.
            zcr_artifact_threshold: Mean zero-crossing rate above which the
                audio before a leading gap is treated as an artifact.
            min_dynamic_range_db: Below this energy range no artifact
                detection is attempted.
            min_silence_pad_ms: Leading gaps shorter than this are only
                trimmed when the preceding burst has high ZCR.
            min_remaining_ms: Inputs shorter than this are skipped, and trims
                that would leave less than this are not made.
            min_trim_ms: Phase 1 trims shorter than this are ignored.
            fade_in_ms, fade_out_ms: Fade lengths applied at cut boundaries.
            target_rms_db, snr_boost_threshold_db, max_gain_db: Gain toward
                target_rms_db (capped at max_gain_db) is applied only when the
                SNR estimate is below snr_boost_threshold_db and RMS is below
                the target.
            clipping_threshold: Absolute peak at/above which input is flagged
                as clipped.
            tighten_boundaries: Enable Phase 2.
            silence_margin_ms: Padding kept before speech onset / after offset.
            output_format: Extension (no dot) for the default output name when
                the input path has no extension.
        """
        self.frame_ms = frame_ms
        self.hop_ms = hop_ms
        self.max_start_trim_ms = max_start_trim_ms
        self.artifact_search_ms = artifact_search_ms
        self.min_gap_ms = min_gap_ms
        self.max_end_trim_ms = max_end_trim_ms
        self.min_burst_ms = min_burst_ms
        self.max_burst_ms = max_burst_ms
        self.energy_gap_factor = energy_gap_factor
        self.zcr_artifact_threshold = zcr_artifact_threshold
        self.min_dynamic_range_db = min_dynamic_range_db
        self.min_silence_pad_ms = min_silence_pad_ms
        self.min_remaining_ms = min_remaining_ms
        self.min_trim_ms = min_trim_ms
        self.fade_in_ms = fade_in_ms
        self.fade_out_ms = fade_out_ms
        self.target_rms_db = target_rms_db
        self.snr_boost_threshold_db = snr_boost_threshold_db
        self.max_gain_db = max_gain_db
        self.clipping_threshold = clipping_threshold
        self.tighten_boundaries_enabled = tighten_boundaries
        self.silence_margin_ms = silence_margin_ms
        self.output_format = output_format

    def _compute_frames(self, audio, sr):
        """
        Compute per-frame energy (dB) and zero-crossing rate.

        Returns:
            (ste_db, zcr, frame_times) arrays of equal length; frame_times is
            each frame's centre in seconds. At least one frame is returned,
            even for audio shorter than a single frame.
        """
        # Clamp to one sample so degenerate settings cannot give a zero hop
        # (ZeroDivisionError below) or zero-length frames.
        frame_samples = max(1, int(self.frame_ms * sr / 1000))
        hop_samples = max(1, int(self.hop_ms * sr / 1000))
        n_frames = max(1, (len(audio) - frame_samples) // hop_samples + 1)
        ste_db = np.zeros(n_frames)
        zcr = np.zeros(n_frames)
        frame_times = np.zeros(n_frames)
        for i in range(n_frames):
            start = i * hop_samples
            frame = audio[start:start + frame_samples]
            # Empty input yields an empty frame: treat it as digital silence
            # instead of taking the mean of an empty array (NaN).
            rms = np.sqrt(np.mean(frame ** 2)) if len(frame) else 0.0
            ste_db[i] = 20 * np.log10(rms + 1e-10)
            if len(frame) > 1:
                zcr[i] = (np.sum(np.abs(np.diff(np.sign(frame))))
                          / (2 * len(frame)))
            frame_times[i] = (start + frame_samples / 2) / sr
        return ste_db, zcr, frame_times

    def _detect_leading_artifact(self, ste_db, zcr, frame_times):
        """
        Detect CUT SPEECH artifact at segment start.

        ONLY acts when: burst (cut speech) -> gap (silence) -> real speech.
        Trims at gap START = removes burst, keeps gap as silence padding.

        Does NOT trim if segment starts with silence (that's good padding).
        Does NOT trim if segment starts with continuous speech (natural start).

        Returns:
            (trim_time_sec or None, quality) with quality one of "clean",
            "marginal_kept" or "artifact_trimmed".
        """
        search_limit = self.artifact_search_ms / 1000
        si = np.where(frame_times < search_limit)[0]
        if len(si) < 3:
            return None, "clean"

        # Noise floor / speech level estimated over the whole segment.
        nf = np.percentile(ste_db, 15)
        sl = np.percentile(ste_db, 85)
        dr = sl - nf
        if dr < self.min_dynamic_range_db:
            return None, "clean"

        gap_thresh = nf + self.energy_gap_factor * dr

        # Segment opens quietly: that is existing padding, keep it.
        n_check = min(3, len(si))
        first_frames_energy = ste_db[si[:n_check]]
        if np.all(first_frames_energy < gap_thresh):
            return None, "clean"

        below = ste_db[si] < gap_thresh

        # Collect quiet runs of at least min_gap_ms as
        # (start index, end index exclusive, duration ms), indices into `si`.
        gaps = []
        in_gap = False
        gs = 0
        for i in range(len(below)):
            if below[i] and not in_gap:
                in_gap = True
                gs = i
            elif not below[i] and in_gap:
                in_gap = False
                dur = (frame_times[si[i]] - frame_times[si[gs]]) * 1000
                if dur >= self.min_gap_ms:
                    gaps.append((gs, i, dur))
        if in_gap:
            dur = (frame_times[si[-1]] - frame_times[si[gs]]) * 1000
            if dur >= self.min_gap_ms:
                gaps.append((gs, len(below), dur))

        if not gaps:
            return None, "clean"

        g = gaps[0]
        if g[0] == 0:
            return None, "clean"
        # Gap runs to the end of the search window: speech after it is not
        # confirmed, so do not trim.
        if g[1] >= len(si):
            return None, "marginal_kept"

        pre = si[:g[0]]
        mzcr = np.mean(zcr[pre]) if len(pre) > 0 else 0

        gap_start_frame = si[g[0]]
        tt = frame_times[gap_start_frame]

        if tt * 1000 > self.max_start_trim_ms or tt * 1000 < self.min_trim_ms:
            return None, "marginal_kept" if tt * 1000 > self.max_start_trim_ms else "clean"

        # Short gaps are only trusted when the burst looks noisy (high ZCR).
        gap_ms = g[2]
        if gap_ms < self.min_silence_pad_ms:
            if mzcr <= self.zcr_artifact_threshold:
                return None, "marginal_kept"

        if mzcr > self.zcr_artifact_threshold:
            return tt, "artifact_trimmed"
        if g[2] >= self.min_gap_ms * 1.5:
            return tt, "artifact_trimmed"

        return None, "marginal_kept"

    def _detect_trailing_artifact(self, ste_db, frame_times, total_dur):
        """
        Detect CUT SPEECH artifact at segment end.

        ONLY acts when: speech -> silence gap -> isolated burst at end.
        The burst is the final run of loud frames. It is trimmed when the
        (up to 10) frames before it are mostly quiet, there is loud audio
        earlier in the segment, its length is within
        [min_burst_ms, max_burst_ms], and the trim is within
        [min_trim_ms, max_end_trim_ms].

        Returns:
            (trim_time_sec or None, quality, needs_fade_out). needs_fade_out
            is True only when the segment ends on loud frames that are kept.
        """
        nf = np.percentile(ste_db, 15)
        sl = np.percentile(ste_db, 85)
        dr = sl - nf
        if dr < self.min_dynamic_range_db:
            return None, "clean", False

        thresh = nf + self.energy_gap_factor * dr
        above = ste_db > thresh

        # Already ends in quiet frames: natural padding, nothing to do.
        n_check = min(3, len(ste_db))
        if not np.any(above[-n_check:]):
            return None, "clean", False

        last_above = int(np.flatnonzero(above)[-1])

        # Walk back over the contiguous loud run that ends at last_above.
        burst_start = last_above
        while burst_start > 0 and above[burst_start - 1]:
            burst_start -= 1

        # Isolation is judged on frames BEFORE the burst. Measuring a window
        # that ended at last_above counted the burst's own frames, so any
        # burst longer than ~2 hops failed the < 0.3 test and max_burst_ms
        # was unreachable.
        lookback = min(10, burst_start)
        if lookback >= 3:
            isolated = np.mean(above[burst_start - lookback:burst_start]) < 0.3
            # Never remove the only loud audio in the segment.
            has_speech_before = np.any(above[:burst_start])
            if isolated and has_speech_before:
                burst_dur = (frame_times[last_above] - frame_times[burst_start]) * 1000
                if self.min_burst_ms <= burst_dur <= self.max_burst_ms:
                    tt = frame_times[burst_start]
                    actual_trim = (total_dur - tt) * 1000
                    if self.min_trim_ms <= actual_trim <= self.max_end_trim_ms:
                        return tt, "artifact_trimmed", False

        return None, "clean", True  # ends on kept energy -> needs fade-out

    # === Phase 2: Energy-based boundary tightening ===

    def _find_speech_onset(self, ste_db, frame_times, noise_floor_db):
        """
        Find precise speech onset: first sustained energy rise above noise floor.

        Speech starts when energy exceeds noise_floor_db + _SPEECH_MARGIN_DB
        for _SPEECH_CONSEC_FRAMES consecutive frames (~15ms at 5ms hop), so
        single-frame noise spikes are not taken as onset.

        Returns: onset time in seconds, or None if no clear onset found.
        """
        thresh = noise_floor_db + self._SPEECH_MARGIN_DB
        consec_needed = self._SPEECH_CONSEC_FRAMES

        consec = 0
        for i, e in enumerate(ste_db):
            if e > thresh:
                consec += 1
                if consec >= consec_needed:
                    onset_frame = i - consec_needed + 1
                    return frame_times[onset_frame]
            else:
                consec = 0
        return None

    def _find_speech_offset(self, ste_db, frame_times, noise_floor_db):
        """
        Find precise speech offset: last sustained energy above noise floor.

        Mirrors onset detection, scanning backward; returns the time of the
        last frame of the final qualifying run.

        Returns: offset time in seconds, or None if no clear offset found.
        """
        thresh = noise_floor_db + self._SPEECH_MARGIN_DB
        consec_needed = self._SPEECH_CONSEC_FRAMES

        consec = 0
        for i in range(len(ste_db) - 1, -1, -1):
            if ste_db[i] > thresh:
                consec += 1
                if consec >= consec_needed:
                    offset_frame = min(i + consec_needed - 1, len(frame_times) - 1)
                    return frame_times[offset_frame]
            else:
                consec = 0
        return None

    def _tighten_to_speech(self, audio, sr, ste_db, frame_times):
        """
        Tighten segment boundaries to actual speech onset/offset.

        Locates where speech energy begins and ends (noise floor = 10th
        percentile of frame energy) and cuts to those points padded by
        silence_margin_ms.

        Returns: (tightened_audio, start_trim_ms, end_trim_ms), or
                 (None, 0.0, 0.0) when no speech is found, neither end would
                 lose at least _MIN_TIGHTEN_MS, or the result would be
                 shorter than min_remaining_ms.
        """
        total_dur = len(audio) / sr
        noise_floor = np.percentile(ste_db, 10)

        onset = self._find_speech_onset(ste_db, frame_times, noise_floor)
        offset = self._find_speech_offset(ste_db, frame_times, noise_floor)

        if onset is None or offset is None:
            return None, 0.0, 0.0

        margin_sec = self.silence_margin_ms / 1000.0
        new_start = max(0.0, onset - margin_sec)
        new_end = min(total_dur, offset + margin_sec)

        start_trim = new_start * 1000
        end_trim = (total_dur - new_end) * 1000

        # Only tighten if at least one end loses a meaningful amount.
        if start_trim < self._MIN_TIGHTEN_MS and end_trim < self._MIN_TIGHTEN_MS:
            return None, 0.0, 0.0

        new_dur_ms = (new_end - new_start) * 1000
        if new_dur_ms < self.min_remaining_ms:
            return None, 0.0, 0.0

        ss = int(new_start * sr)
        es = int(new_end * sr)
        return audio[ss:es], start_trim, end_trim

    def _measure_snr(self, ste_db):
        """Estimate SNR as the 90th-10th percentile spread of frame energy (dB)."""
        return np.percentile(ste_db, 90) - np.percentile(ste_db, 10)

    def _measure_levels(self, audio):
        """
        Measure whole-signal level.

        Returns:
            (rms_db, peak_db, is_clipped). Empty audio is treated as silence
            (-200 dB) instead of raising on ``np.max`` of an empty array.
        """
        if len(audio) == 0:
            rms_val = peak = 0.0
        else:
            rms_val = np.sqrt(np.mean(audio ** 2))
            peak = np.max(np.abs(audio))
        rms_db = 20 * np.log10(rms_val + 1e-10)
        peak_db = 20 * np.log10(peak + 1e-10)
        return rms_db, peak_db, bool(peak >= self.clipping_threshold)

    def _volume_gain(self, snr_db, rms_db):
        """
        Decide whether to boost a quiet, low-SNR segment.

        Returns:
            (apply, gain_db). Gain is target_rms_db - rms_db capped at
            max_gain_db; boosts of 1 dB or less are dropped as (False, 0.0).
        """
        if snr_db < self.snr_boost_threshold_db and rms_db < self.target_rms_db:
            gain = min(self.target_rms_db - rms_db, self.max_gain_db)
            if gain > 1.0:
                return True, gain
        return False, 0.0

    @staticmethod
    def _apply_gain(audio, gain_db, ceiling=0.99):
        """Scale by gain_db, then rescale so the absolute peak stays <= ceiling."""
        scaled = audio * (10 ** (gain_db / 20))
        peak = np.max(np.abs(scaled)) if len(scaled) else 0.0
        if peak > ceiling:
            scaled = scaled * (ceiling / peak)
        return scaled

    def _resolve_output_path(self, audio_path, output_path, output_dir):
        """
        Choose the destination for polished audio.

        Precedence: explicit output_path; else output_dir + input file name
        (the directory is created); else ``<stem>_polished<ext>`` beside the
        input, using output_format when the input has no extension.
        """
        if output_path is not None:
            return output_path
        src = Path(audio_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            return os.path.join(output_dir, src.name)
        ext = src.suffix or f".{self.output_format}"
        return str(src.parent / f"{src.stem}_polished{ext}")

    def polish(self, audio_path, output_path=None, output_dir=None):
        """
        Polish a single audio segment (two-phase).

        Phase 1: Remove cut speech artifacts (burst-gap-speech pattern)
        Phase 2: Tighten boundaries to actual speech onset/offset + silence margin

        Multi-channel input is downmixed by averaging channels. Already-clean
        segments pass through unchanged (was_modified=False) and nothing is
        written; otherwise the result is written with soundfile.

        Args:
            audio_path: Segment to read.
            output_path: Explicit destination for the polished file.
            output_dir: Directory for the polished file when output_path is None.

        Returns:
            PolishResult describing trims, quality labels and gain.
        """
        audio, sr = sf.read(audio_path)
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        orig_ms = len(audio) / sr * 1000
        total_sec = len(audio) / sr

        rms_db, peak_db, is_clipped = self._measure_levels(audio)

        # Too short (including empty files) to polish safely.
        if orig_ms < self.min_remaining_ms:
            return PolishResult(
                input_path=audio_path, output_path=audio_path,
                was_modified=False, original_duration_ms=orig_ms,
                polished_duration_ms=orig_ms, rms_db=rms_db,
                peak_db=peak_db, is_clipped=is_clipped)

        ste_db, zcr, ft = self._compute_frames(audio, sr)
        snr_db = self._measure_snr(ste_db)

        # === Phase 1: Artifact removal ===
        s_sec, s_q = self._detect_leading_artifact(ste_db, zcr, ft)
        e_sec, e_q, needs_fade_out = self._detect_trailing_artifact(
            ste_db, ft, total_sec)

        s_ms = (s_sec * 1000) if s_sec is not None else 0.0
        e_ms = ((total_sec - e_sec) * 1000) if e_sec is not None else 0.0

        # Safety: don't shrink below minimum
        if orig_ms - s_ms - e_ms < self.min_remaining_ms:
            s_sec, e_sec = None, None
            s_ms, e_ms = 0.0, 0.0
            s_q = "marginal_kept" if s_q != "clean" else "clean"
            e_q = "marginal_kept" if e_q != "clean" else "clean"
            needs_fade_out = False

        vol_adj, vol_gain = self._volume_gain(snr_db, rms_db)

        # Apply Phase 1 trim (fades below always copy, so a view is fine).
        ss = int(s_sec * sr) if s_sec is not None else 0
        es = int(e_sec * sr) if e_sec is not None else len(audio)
        pol = audio[ss:es]

        phase1_modified = (s_sec is not None or e_sec is not None)

        # === Phase 2: Tighten boundaries to speech onset/offset ===
        if self.tighten_boundaries_enabled and len(pol) / sr * 1000 >= self.min_remaining_ms:
            ste2, _, ft2 = self._compute_frames(pol, sr)
            tightened, t_start, t_end = self._tighten_to_speech(
                pol, sr, ste2, ft2
            )
            if tightened is not None:
                pol = tightened
                s_ms += t_start
                e_ms += t_end
                if s_q == "clean" and t_start > self._MIN_TIGHTEN_MS:
                    s_q = "boundary_tightened"
                if e_q == "clean" and t_end > self._MIN_TIGHTEN_MS:
                    e_q = "boundary_tightened"

        # Fade any cut boundary to avoid clicks: the start whenever it was
        # trimmed, the end when it was trimmed or ends on kept energy.
        pol = self._apply_fade(
            pol, sr,
            fade_in=(s_ms > 0),
            fade_out=(needs_fade_out or e_ms > 0)
        )

        if vol_adj:
            pol = self._apply_gain(pol, vol_gain)

        needs_mod = (phase1_modified or s_ms > 0 or e_ms > 0 or vol_adj)

        if not needs_mod:
            return PolishResult(
                input_path=audio_path, output_path=audio_path,
                was_modified=False, start_quality=s_q, end_quality=e_q,
                original_duration_ms=orig_ms, polished_duration_ms=orig_ms,
                snr_db=snr_db, rms_db=rms_db, peak_db=peak_db,
                is_clipped=is_clipped)

        pol_ms = len(pol) / sr * 1000
        output_path = self._resolve_output_path(audio_path, output_path, output_dir)
        sf.write(output_path, pol, sr)

        return PolishResult(
            input_path=audio_path, output_path=output_path,
            was_modified=True, start_trimmed_ms=s_ms, end_trimmed_ms=e_ms,
            start_quality=s_q, end_quality=e_q,
            original_duration_ms=orig_ms, polished_duration_ms=pol_ms,
            snr_db=snr_db, rms_db=rms_db, peak_db=peak_db,
            is_clipped=is_clipped, volume_adjusted=vol_adj,
            volume_gain_db=vol_gain)

    def _apply_fade(self, audio, sr, fade_in=False, fade_out=False):
        """Return a copy of audio with linear fade-in/fade-out to prevent clicks."""
        out = audio.copy()
        if fade_in and self.fade_in_ms > 0:
            n = int(self.fade_in_ms * sr / 1000)
            if n > 0 and n < len(out):
                out[:n] *= np.linspace(0, 1, n)
        if fade_out and self.fade_out_ms > 0:
            n = int(self.fade_out_ms * sr / 1000)
            if n > 0 and n < len(out):
                out[-n:] *= np.linspace(1, 0, n)
        return out

    def polish_directory(self, input_dir, output_dir=None, max_files=None):
        """
        Polish all audio segments directly inside input_dir (non-recursive).

        Files are matched by extension case-insensitively and processed in
        name order; output defaults to ``<input_dir>/polished``. A missing
        input_dir yields no files.

        Returns:
            List of PolishResult, one per processed file.
        """
        if output_dir is None:
            output_dir = os.path.join(input_dir, "polished")
        exts = {".flac", ".wav", ".mp3", ".ogg", ".m4a"}
        root = Path(input_dir)
        # One scan with lower-cased suffixes so e.g. "clip.WAV" is included
        # on case-sensitive filesystems; is_file() skips directories.
        candidates = root.iterdir() if root.is_dir() else ()
        files = sorted(
            (p for p in candidates if p.is_file() and p.suffix.lower() in exts),
            key=lambda x: x.name)
        if max_files:
            files = files[:max_files]
        print(f"[Polisher] Processing {len(files)} segments...")
        results = []
        modified = 0
        for f in files:
            r = self.polish(str(f), output_dir=output_dir)
            results.append(r)
            if r.was_modified:
                modified += 1
            print(r.summary())
        print(f"[Polisher] Done: {modified}/{len(files)} modified, "
              f"{len(files) - modified} already clean")
        return results
