import struct

from veena3modal.audio.utils import create_wav_header, add_wav_header


def test_create_wav_header_has_correct_riff_and_length_fields():
    """Check the 44-byte PCM WAV header: chunk tags, size fields, and fmt parameters.

    Besides the RIFF/data length fields, this verifies that the
    ``sample_rate``/``channels``/``bits_per_sample`` arguments are actually
    encoded in the fmt subchunk (bytes 16:36) using the standard layout.
    """
    data_size = 100
    sample_rate = 16000
    channels = 1
    bits_per_sample = 16
    header = create_wav_header(
        sample_rate=sample_rate,
        channels=channels,
        bits_per_sample=bits_per_sample,
        data_size=data_size,
    )

    assert len(header) == 44
    assert header[:4] == b"RIFF"
    assert header[8:12] == b"WAVE"
    assert header[12:16] == b"fmt "
    assert header[36:40] == b"data"

    # RIFF chunk size field at bytes 4:8 counts everything after it: 36 + data_size.
    (chunk_size,) = struct.unpack_from("<I", header, 4)
    assert chunk_size == 36 + data_size

    # fmt subchunk body (bytes 16:36), little-endian per the RIFF WAVE spec.
    (
        subchunk1_size,
        audio_format,
        num_channels,
        encoded_rate,
        byte_rate,
        block_align,
        encoded_bits,
    ) = struct.unpack_from("<IHHIIHH", header, 16)
    expected_block_align = channels * bits_per_sample // 8
    assert subchunk1_size == 16
    assert audio_format == 1  # 1 == uncompressed PCM
    assert num_channels == channels
    assert encoded_rate == sample_rate
    assert block_align == expected_block_align
    assert byte_rate == sample_rate * expected_block_align
    assert encoded_bits == bits_per_sample

    # data subchunk size at bytes 40:44 should match data_size.
    (subchunk2_size,) = struct.unpack_from("<I", header, 40)
    assert subchunk2_size == data_size


def test_add_wav_header_prepends_header_and_preserves_payload():
    """Check that the payload follows a 44-byte header whose fields describe it.

    A suffix-only comparison would pass even if extra bytes were inserted
    between header and payload, so the payload is pinned to start at byte 44,
    and the RIFF/data size fields are checked against the payload length.
    """
    payload = b"\x00\x01" * 50  # 100 bytes
    sample_rate = 16000
    wav = add_wav_header(payload, sample_rate=sample_rate)

    assert len(wav) == 44 + len(payload)
    assert wav[:4] == b"RIFF"
    assert wav[8:12] == b"WAVE"
    # Payload must begin exactly after the header, byte-for-byte unchanged.
    assert wav[44:] == payload

    # Size fields must reflect the payload that was actually appended.
    (chunk_size,) = struct.unpack_from("<I", wav, 4)
    assert chunk_size == 36 + len(payload)
    (data_size,) = struct.unpack_from("<I", wav, 40)
    assert data_size == len(payload)

    # The requested sample rate is encoded in the fmt subchunk at bytes 24:28.
    (encoded_rate,) = struct.unpack_from("<I", wav, 24)
    assert encoded_rate == sample_rate


