"""Tests for audio polishing pipeline."""
import numpy as np
import pytest
import soundfile as sf
import tempfile
from pathlib import Path

from src.audio_polish import (
    step1_length_split, step2_boundary_trim, step3_silence_pad,
    step4_encode, polish_segment, polish_all_segments,
    _compute_rms_profile, _find_silence_valleys,
    TrimMetadata,
)


def _make_audio(duration_s: float, sr: int = 16000, freq: float = 300.0,
                seed: "int | None" = None) -> np.ndarray:
    """Generate a sine wave with some noise.

    Args:
        duration_s: Length of the clip in seconds.
        sr: Sample rate in Hz.
        freq: Frequency of the sine tone in Hz.
        seed: Optional seed for the noise. When given, the clip is fully
            reproducible; when ``None`` the noise is drawn from NumPy's
            global random state.

    Returns:
        A 1-D float32 array of ``int(sr * duration_s)`` samples: a sine of
        amplitude 0.3 plus Gaussian noise scaled by 0.02.
    """
    n_samples = int(sr * duration_s)
    # Sample k sits at time k / sr. (linspace with its default endpoint=True
    # would stretch the spacing to duration / (n - 1) and detune the tone.)
    t = np.arange(n_samples) / sr
    if seed is None:
        noise = np.random.randn(n_samples)
    else:
        noise = np.random.default_rng(seed).standard_normal(n_samples)
    return (0.3 * np.sin(2 * np.pi * freq * t) + 0.02 * noise).astype(np.float32)


def _make_audio_with_silence(duration_s: float, silence_at_s: float,
                              silence_dur_s: float = 0.3, sr: int = 16000) -> np.ndarray:
    """Build a noisy tone and attenuate one stretch of it to near-silence.

    The samples from ``silence_at_s`` for ``silence_dur_s`` seconds (clipped
    to the end of the clip) are scaled by 0.001.
    """
    signal = _make_audio(duration_s, sr)
    first = int(silence_at_s * sr)
    n_quiet = int(silence_dur_s * sr)
    # Never let the quiet window run past the last sample.
    quiet = slice(first, min(first + n_quiet, len(signal)))
    signal[quiet] *= 0.001
    return signal


def _write_flac(audio: np.ndarray, sr: int = 16000) -> Path:
    """Write *audio* to a fresh temporary ``.flac`` file and return its path.

    The file is created with ``delete=False``, so the caller is responsible
    for removing it.
    """
    # Only reserve a unique name here. Closing the handle before soundfile
    # opens the path avoids leaking the descriptor, and re-opening a
    # still-open named temporary file is not possible on Windows.
    with tempfile.NamedTemporaryFile(suffix=".flac", delete=False) as tmp:
        target = Path(tmp.name)
    sf.write(str(target), audio, sr)
    return target


class TestRMSProfile:
    """Checks on the RMS profile returned by ``_compute_rms_profile``."""

    def test_rms_profile_shape(self):
        """One second at 16 kHz with 10 ms frames gives 100 RMS values."""
        clip = _make_audio(1.0)
        profile = _compute_rms_profile(clip, 16000, frame_ms=10)
        expected_frames = 100  # 1s / 10ms = 100 frames
        assert len(profile) == expected_frames

    def test_silence_has_low_rms(self):
        """An all-zero clip has every RMS value below 0.001."""
        sample_rate = 16000
        silent = np.zeros(sample_rate, dtype=np.float32)
        profile = _compute_rms_profile(silent, sample_rate)
        assert np.all(profile < 0.001)

    def test_loud_audio_has_high_rms(self):
        """A constant 0.5 signal has every RMS value above 0.4."""
        sample_rate = 16000
        loud = np.full(sample_rate, 0.5, dtype=np.float32)
        profile = _compute_rms_profile(loud, sample_rate)
        assert np.all(profile > 0.4)


class TestSilenceValleys:
    """Checks on ``_find_silence_valleys``."""

    def test_finds_valleys_in_silence_region(self):
        """A 5 s clip with a quiet stretch at 2.5 s yields at least one valley."""
        clip = _make_audio_with_silence(5.0, silence_at_s=2.5)
        profile = _compute_rms_profile(clip, 16000)
        found = _find_silence_valleys(profile)
        assert len(found) >= 1


class TestStep1LengthSplit:
    """Checks on ``step1_length_split`` for short, normal and long clips."""

    def test_short_segment_discarded(self):
        """A 1 s clip comes back as a single result flagged as discarded."""
        pieces = step1_length_split(_make_audio(1.0), 16000, "short.flac")
        assert len(pieces) == 1
        _chunk, info = pieces[0]
        assert info.discarded is True

    def test_normal_segment_not_split(self):
        """A 5 s clip comes back as one piece, neither discarded nor split."""
        pieces = step1_length_split(_make_audio(5.0), 16000, "normal.flac")
        assert len(pieces) == 1
        _chunk, info = pieces[0]
        assert info.discarded is False
        assert info.was_split is False

    def test_long_segment_gets_split(self):
        """A 20 s clip is cut into at least two pieces of bounded length."""
        sample_rate = 16000
        clip = _make_audio_with_silence(20.0, silence_at_s=8.0)
        pieces = step1_length_split(clip, sample_rate, "long.flac")
        assert len(pieces) >= 2
        kept_chunks = [chunk for chunk, info in pieces if not info.discarded]
        for chunk in kept_chunks:
            assert len(chunk) / sample_rate <= 15.5  # max + tolerance

    def test_10s_segment_not_split(self):
        """A 10 s clip comes back as one piece that is not flagged as split."""
        pieces = step1_length_split(_make_audio(10.0), 16000, "10s.flac")
        assert len(pieces) == 1
        _chunk, info = pieces[0]
        assert info.was_split is False


class TestStep2BoundaryTrim:
    """Checks on ``step2_boundary_trim``."""

    def test_clean_boundaries_no_trim(self):
        """A clip framed by 0.1 s of digital silence is not discarded."""
        sample_rate = 16000
        # Start and end with silence
        edge = np.zeros(int(0.1 * sample_rate), dtype=np.float32)
        clip = np.concatenate((edge, _make_audio(3.0, sample_rate), edge))
        info = TrimMetadata(
            original_file="clean.flac",
            original_duration_ms=len(clip)/sample_rate*1000,
        )
        _trimmed, info = step2_boundary_trim(clip, sample_rate, info)
        assert not info.discarded

    def test_discarded_segment_passes_through(self):
        """Metadata already marked as discarded stays discarded."""
        clip = _make_audio(1.0)
        info = TrimMetadata(
            original_file="x.flac",
            original_duration_ms=1000,
            discarded=True,
            discard_reason="test",
        )
        _result, info = step2_boundary_trim(clip, 16000, info)
        assert info.discarded is True


class TestStep3SilencePad:
    """Checks on ``step3_silence_pad``."""

    def test_adds_padding(self):
        """Padding adds 0.15 s of samples at each end and records 150 ms."""
        sample_rate = 16000
        clip = _make_audio(3.0, sample_rate)
        info = TrimMetadata(original_file="x.flac", original_duration_ms=3000)
        padded, info = step3_silence_pad(clip, sample_rate, info)
        per_side = int(sample_rate * 0.15)
        assert len(padded) == len(clip) + 2 * per_side
        assert info.leading_pad_ms == 150
        assert info.trailing_pad_ms == 150

    def test_padding_is_silence(self):
        """The first and last 0.15 s of the padded clip are exactly zero."""
        sample_rate = 16000
        clip = _make_audio(3.0, sample_rate)
        info = TrimMetadata(original_file="x.flac", original_duration_ms=3000)
        padded, info = step3_silence_pad(clip, sample_rate, info)
        per_side = int(sample_rate * 0.15)
        head = padded[:per_side]
        tail = padded[-per_side:]
        assert np.all(head == 0)
        assert np.all(tail == 0)


class TestStep4Encode:
    """Checks on ``step4_encode``."""

    def test_encode_produces_valid_flac(self):
        """Both outputs are non-empty and the base64 decodes to the raw bytes."""
        from base64 import b64decode

        raw, encoded = step4_encode(_make_audio(2.5), 16000)
        assert len(raw) > 0
        assert len(encoded) > 0
        # Decode base64 back
        round_trip = b64decode(encoded)
        assert round_trip == raw


class TestPolishSegment:
    """End-to-end runs of ``polish_segment`` on temporary FLAC files."""

    def test_full_pipeline_on_normal_segment(self):
        """A 5 s clip yields at least one kept result with encoded audio."""
        flac_path = _write_flac(_make_audio(5.0))
        try:
            kept = [
                item for item in polish_segment(flac_path)
                if not item.trim_meta.discarded
            ]
            assert len(kept) >= 1
            for item in kept:
                assert len(item.flac_bytes) > 0
                assert len(item.base64_audio) > 0
                assert item.trim_meta.final_duration_ms > 0
        finally:
            # Always remove the temporary input, even when an assertion fails.
            flac_path.unlink(missing_ok=True)

    def test_full_pipeline_on_short_segment(self):
        """Every result for a 0.5 s clip is flagged as discarded."""
        flac_path = _write_flac(_make_audio(0.5))
        try:
            for item in polish_segment(flac_path):
                assert item.trim_meta.discarded
        finally:
            flac_path.unlink(missing_ok=True)

    def test_full_pipeline_on_long_segment(self):
        """A 25 s clip with a quiet stretch yields at least two kept results."""
        clip = _make_audio_with_silence(25.0, silence_at_s=8.0)
        flac_path = _write_flac(clip)
        try:
            kept = [
                item for item in polish_segment(flac_path)
                if not item.trim_meta.discarded
            ]
            assert len(kept) >= 2
        finally:
            flac_path.unlink(missing_ok=True)


class TestPolishAllSegments:
    """Batch-level check on ``polish_all_segments``."""

    def test_batch_polish(self):
        """Polishing four clips keeps at least three non-discarded results."""
        paths = []
        try:
            # Create the inputs inside the try so that files already written
            # are still removed if a later write fails.
            for dur in (3.0, 5.0, 0.5, 8.0):
                paths.append(_write_flac(_make_audio(dur)))
            results = polish_all_segments(paths)
            valid = [r for r in results if not r.trim_meta.discarded]
            assert len(valid) >= 3  # 0.5s should be discarded
        finally:
            for p in paths:
                p.unlink(missing_ok=True)
