import numpy as np

from veena3modal.audio.crossfade import crossfade_bytes_int16


def _pcm_bytes(value: int, samples: int) -> bytes:
    """Return ``samples`` constant int16 PCM samples equal to ``value`` as raw bytes.

    The array is serialized with ``ndarray.tobytes`` (native byte order), so the
    result is exactly ``2 * samples`` bytes long.

    Args:
        value: Sample value; must fit in a signed 16-bit integer.
        samples: Number of samples to generate.

    Raises:
        ValueError: If ``value`` is outside the int16 range. The previous
            ``ones() * value`` formulation silently wrapped on NumPy 1.x and
            raised ``OverflowError`` on NumPy 2.x; an explicit check makes the
            fixture fail loudly and consistently.
    """
    bounds = np.iinfo(np.int16)
    if not bounds.min <= value <= bounds.max:
        raise ValueError(
            f"value {value} is outside the int16 range [{bounds.min}, {bounds.max}]"
        )
    # np.full builds the constant array directly in int16, avoiding the
    # temporary produced by ones() * value and the redundant astype() copy.
    return np.full(samples, value, dtype=np.int16).tobytes()


def test_crossfade_bytes_int16_initial_holds_tail_and_emits_body():
    """With no previous tail, expect one crossfade window held back and the rest emitted."""
    rate = 1000
    fade_ms = 10
    # 10 ms at 1 kHz is 10 int16 samples, i.e. 20 bytes.
    fade_len = 2 * int(rate * fade_ms / 1000)

    chunk = _pcm_bytes(100, samples=50)  # 50 samples -> 100 bytes
    emitted, held = crossfade_bytes_int16(
        None,
        chunk,
        sample_rate_hz=rate,
        crossfade_ms=fade_ms,
    )

    assert len(held) == fade_len
    assert len(emitted) == len(chunk) - fade_len
    # Both outputs must stay aligned to whole int16 samples.
    for part in (emitted, held):
        assert len(part) % 2 == 0


def test_crossfade_bytes_int16_full_overlap_emits_crossfaded_plus_middle_and_holds_new_tail():
    """A chunk longer than the window: emit everything but the new tail, hold one window."""
    rate = 1000
    fade_ms = 10
    # 10 ms at 1 kHz is 10 int16 samples, i.e. 20 bytes.
    fade_len = 2 * int(rate * fade_ms / 1000)

    previous = _pcm_bytes(100, samples=fade_len // 2)
    chunk = _pcm_bytes(200, samples=50)  # 50 samples -> 100 bytes

    emitted, held = crossfade_bytes_int16(
        previous,
        chunk,
        sample_rate_hz=rate,
        crossfade_ms=fade_ms,
    )

    assert len(held) == fade_len
    # Emitted = crossfaded overlap + untouched middle, i.e. all of chunk minus the held tail.
    assert len(emitted) == len(chunk) - fade_len
    # Both outputs must stay aligned to whole int16 samples.
    for part in (emitted, held):
        assert len(part) % 2 == 0


def test_crossfade_bytes_int16_partial_overlap_when_new_chunk_is_too_short():
    """A chunk shorter than the window still flushes at least the previous tail's length."""
    rate = 1000
    fade_ms = 10
    # 10 ms at 1 kHz is 10 int16 samples, i.e. 20 bytes.
    fade_len = 2 * int(rate * fade_ms / 1000)

    previous = _pcm_bytes(100, samples=fade_len // 2)
    chunk = _pcm_bytes(200, samples=5)  # 5 samples -> 10 bytes, shorter than the window

    emitted, held = crossfade_bytes_int16(
        previous,
        chunk,
        sample_rate_hz=rate,
        crossfade_ms=fade_ms,
    )

    # In the partial case the emitted audio must cover at least the whole previous tail.
    assert len(emitted) >= len(previous)
    assert len(held) <= fade_len
    # Both outputs must stay aligned to whole int16 samples.
    for part in (emitted, held):
        assert len(part) % 2 == 0


