import numpy as np
import soundfile as sf
from pathlib import Path

OUT = Path("/home/ubuntu/sam_audio_test")  # directory holding both input and output WAVs

GAP_SEC = 0.15        # silence inserted between kept speech segments, in seconds
THRESHOLD_DB = -40    # RMS gate level in dBFS; frames below this count as silence
MIN_SPEECH_SEC = 0.1  # speech bursts shorter than this are discarded as noise
MIN_SILENCE_SEC = 0.3 # silences shorter than this are merged into surrounding speech

def strip_silence(input_path, output_path):
    """Remove long silences from an audio file and write a compacted copy.

    Speech is detected per 20 ms frame (50% overlap) with an RMS gate at
    THRESHOLD_DB. Silences shorter than MIN_SILENCE_SEC are merged into the
    surrounding speech, speech bursts shorter than MIN_SPEECH_SEC are dropped,
    and the surviving segments are rejoined with GAP_SEC of silence between
    them.

    Args:
        input_path: path to the source audio file (anything soundfile reads).
        output_path: path the compacted audio is written to.

    Returns:
        (original_duration_sec, new_duration_sec, segment_count)
    """
    audio, sr = sf.read(str(input_path))
    if audio.ndim == 1:
        # Normalize to 2-D (samples, channels) so slicing/concat is uniform.
        audio = audio.reshape(-1, 1)

    mono = audio.mean(axis=1)
    threshold = 10 ** (THRESHOLD_DB / 20)  # dBFS -> linear amplitude

    frame_len = int(0.02 * sr)  # 20 ms analysis frames
    hop = frame_len // 2        # 50% overlap
    # max(0, ...) guards audio shorter than a single frame.
    n_frames = max(0, (len(mono) - frame_len) // hop + 1)

    rms = np.array([
        np.sqrt(np.mean(mono[i * hop : i * hop + frame_len] ** 2))
        for i in range(n_frames)
    ])

    is_speech = rms > threshold

    min_speech_frames = int(MIN_SPEECH_SEC / (hop / sr))
    min_silence_frames = int(MIN_SILENCE_SEC / (hop / sr))

    # Pass 1: merge short silence gaps into surrounding speech.
    # BUG FIX: this previously used `for i in range(...)` while also mutating
    # `i` inside the body; the for-loop reset `i` each iteration, so it
    # restarted inside long silence runs and wrongly re-flagged their tails
    # (up to min_silence_frames - 1 frames) as speech. Use the same explicit
    # `while` run-scanning pattern as the passes below.
    i = 0
    while i < len(is_speech):
        if not is_speech[i]:
            start = i
            while i < len(is_speech) and not is_speech[i]:
                i += 1
            if (i - start) < min_silence_frames:
                is_speech[start:i] = True
        else:
            i += 1

    # Pass 2: drop speech bursts too short to be real speech.
    i = 0
    while i < len(is_speech):
        if is_speech[i]:
            start = i
            while i < len(is_speech) and is_speech[i]:
                i += 1
            if (i - start) < min_speech_frames:
                is_speech[start:i] = False
        else:
            i += 1

    # Pass 3: collect speech segments, padded by 20 ms on each side.
    segments = []
    i = 0
    while i < len(is_speech):
        if is_speech[i]:
            start = i
            while i < len(is_speech) and is_speech[i]:
                i += 1
            s_sample = max(0, start * hop - int(0.02 * sr))
            e_sample = min(len(audio), i * hop + int(0.02 * sr))
            segments.append(audio[s_sample:e_sample])
        else:
            i += 1

    if segments:
        gap = np.zeros((int(GAP_SEC * sr), audio.shape[1]))
        parts = []
        for idx, seg in enumerate(segments):
            parts.append(seg)
            if idx < len(segments) - 1:
                parts.append(gap)
        result = np.concatenate(parts, axis=0)
    else:
        # No speech detected: write an empty file instead of crashing in
        # np.concatenate on an empty list.
        result = np.zeros((0, audio.shape[1]))

    sf.write(str(output_path), result, sr)

    orig_dur = len(audio) / sr
    new_dur = len(result) / sr
    return orig_dur, new_dur, len(segments)


# (input filename, output filename) pairs; both are resolved relative to OUT.
files = [
    ("woman_speaking_target.wav", "woman_speaking_compact.wav"),
    ("speech_target.wav", "speech_compact.wav"),
]

# Process each file and report the duration before/after, segment count,
# and percentage of audio removed.
for inp, outp in files:
    orig, new, segs = strip_silence(OUT / inp, OUT / outp)
    print(f"{inp}:")
    print(f"  {orig:.1f}s -> {new:.1f}s ({segs} segments, {(1-new/orig)*100:.0f}% removed)")
    print(f"  -> {outp}")
