import numpy as np
import soundfile as sf
from pathlib import Path

# Input recording and destination for the silence-trimmed copy.
_DATA_DIR = Path("/home/ubuntu/sam_audio_test")
INPUT = _DATA_DIR / "man_speaking_target.wav"
OUTPUT = _DATA_DIR / "man_speaking_compact.wav"

# Tuning knobs: durations are in seconds; the level is in dB relative to a
# linear amplitude of 1.0.
GAP_SEC = 0.15          # digital silence inserted between kept speech segments
THRESHOLD_DB = -40      # frame RMS must exceed this level to count as speech
MIN_SPEECH_SEC = 0.1    # speech runs shorter than this are dropped as noise bursts
MIN_SILENCE_SEC = 0.3   # silent gaps shorter than this are bridged as speech

FRAME_SEC = 0.02  # length of each RMS analysis window; consecutive frames overlap by half
PAD_SEC = 0.02    # audio kept on each side of a speech run so onsets/decays aren't clipped


def _runs(mask):
    """Return ``(start, stop, value)`` for every maximal run of equal values in *mask*.

    ``stop`` is exclusive, so ``mask[start:stop]`` is exactly the run. An empty
    mask yields an empty list.
    """
    n = len(mask)
    if n == 0:
        return []
    # An index whose value differs from its predecessor begins a new run.
    edges = np.flatnonzero(mask[1:] != mask[:-1]) + 1
    starts = np.concatenate(([0], edges))
    stops = np.concatenate((edges, [n]))
    return [(int(s), int(e), bool(mask[s])) for s, e in zip(starts, stops)]


def frame_rms(mono, frame_len, hop):
    """Return the RMS level of each full ``frame_len``-sample window of *mono*.

    Window ``k`` covers ``mono[k * hop : k * hop + frame_len]``. A trailing
    partial window is ignored, and a signal shorter than one window yields an
    empty array.

    Raises:
        ValueError: if ``frame_len`` or ``hop`` is not positive.
    """
    if frame_len <= 0 or hop <= 0:
        raise ValueError(f"frame_len and hop must be positive (got {frame_len}, {hop})")
    squared = np.asarray(mono, dtype=float) ** 2
    if len(squared) < frame_len:
        return np.zeros(0)
    # Strided view over the squared signal: no per-frame copies, no Python loop.
    windows = np.lib.stride_tricks.sliding_window_view(squared, frame_len)[::hop]
    return np.sqrt(windows.mean(axis=1))


def fill_short_gaps(mask, min_frames):
    """Return a copy of *mask* with short interior silent runs set to True.

    A silent run is bridged only when it is shorter than *min_frames* AND has
    speech on both sides. Leading and trailing silence is left alone so it is
    still trimmed.
    """
    out = mask.copy()
    for start, stop, value in _runs(mask):
        interior = start > 0 and stop < len(mask)
        if not value and interior and stop - start < min_frames:
            out[start:stop] = True
    return out


def remove_short_bursts(mask, min_frames):
    """Return a copy of *mask* with speech runs shorter than *min_frames* set False."""
    out = mask.copy()
    for start, stop, value in _runs(mask):
        if value and stop - start < min_frames:
            out[start:stop] = False
    return out


def speech_bounds(mask, hop, frame_len, pad, n_samples):
    """Convert the speech runs of a frame *mask* into ``(start, stop)`` sample ranges.

    Each range runs from the first sample of the run's first frame to the end
    of its last frame. It is widened by *pad* samples on both sides and clipped
    to ``[0, n_samples]``. Ranges that touch or overlap after padding are
    merged, so no audio is emitted twice.
    """
    bounds = []
    for start, stop, value in _runs(mask):
        if not value:
            continue
        s = max(0, start * hop - pad)
        # Last frame of the run starts at (stop - 1) * hop and is frame_len long.
        e = min(n_samples, (stop - 1) * hop + frame_len + pad)
        if bounds and s <= bounds[-1][1]:
            bounds[-1] = (bounds[-1][0], e)  # e only grows, since runs are ordered
        else:
            bounds.append((s, e))
    return bounds


def join_with_gaps(segments, gap_len, n_channels):
    """Concatenate *segments* along time with *gap_len* zero frames between neighbours.

    Returns an empty ``(0, n_channels)`` array when *segments* is empty.
    """
    if not segments:
        return np.zeros((0, n_channels))
    gap = np.zeros((gap_len, n_channels), dtype=segments[0].dtype)
    parts = [segments[0]]
    for seg in segments[1:]:
        parts += [gap, seg]
    return np.concatenate(parts, axis=0)


def main():
    """Remove long silences from INPUT and write the compacted audio to OUTPUT."""
    audio, sr = sf.read(str(INPUT))
    if audio.ndim == 1:
        audio = audio.reshape(-1, 1)  # handle mono files as one-channel 2-D data

    mono = audio.mean(axis=1)
    frame_len = round(FRAME_SEC * sr)
    hop = frame_len // 2

    rms = frame_rms(mono, frame_len, hop)
    is_speech = rms > 10 ** (THRESHOLD_DB / 20)

    # round(), not int(): e.g. 0.3 / 0.01 == 29.999999999999996 in floating point.
    frame_step_sec = hop / sr
    min_speech_frames = round(MIN_SPEECH_SEC / frame_step_sec)
    min_silence_frames = round(MIN_SILENCE_SEC / frame_step_sec)

    is_speech = fill_short_gaps(is_speech, min_silence_frames)
    is_speech = remove_short_bursts(is_speech, min_speech_frames)

    bounds = speech_bounds(is_speech, hop, frame_len, round(PAD_SEC * sr), len(audio))
    if not bounds:
        raise SystemExit(
            f"No speech above {THRESHOLD_DB} dB found in {INPUT}; nothing written."
        )
    segments = [audio[s:e] for s, e in bounds]
    result = join_with_gaps(segments, round(GAP_SEC * sr), audio.shape[1])

    sf.write(str(OUTPUT), result, sr)

    orig_dur = len(audio) / sr
    new_dur = len(result) / sr
    print(f"Original: {orig_dur:.1f}s")
    print(f"Compact:  {new_dur:.1f}s ({len(segments)} speech segments)")
    print(f"Removed:  {orig_dur - new_dur:.1f}s of silence ({(1 - new_dur/orig_dur)*100:.0f}%)")
    print(f"Output:   {OUTPUT}")


if __name__ == "__main__":
    main()
