import math
from typing import Optional, Tuple

import numpy as np


def _equal_power_fade_curves(num_samples: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate equal-power crossfade curves (cosine/sine) of length ``num_samples``.

    The curves are cos/sin of angles spaced evenly over [0, pi/2] (inclusive),
    so ``fade_out**2 + fade_in**2 == 1`` at every sample (constant power).

    Returns:
        (fade_out, fade_in) as float32 arrays with values in [0.0, 1.0]
        (up to float32 rounding at the endpoints): fade_out runs 1.0 -> 0.0
        and fade_in runs 0.0 -> 1.0. Both are empty when ``num_samples <= 0``.
    """
    if num_samples <= 0:
        return (
            np.ones(0, dtype=np.float32),
            np.ones(0, dtype=np.float32),
        )
    # angle from 0 to pi/2
    angles = np.linspace(0.0, math.pi / 2.0, num_samples, dtype=np.float32)
    fade_out = np.cos(angles, dtype=np.float32)  # 1.0 -> 0.0
    fade_in = np.sin(angles, dtype=np.float32)   # 0.0 -> 1.0
    return fade_out, fade_in


def crossfade_bytes_int16(
    previous_tail: Optional[bytes],
    new_audio_bytes: bytes,
    *,
    sample_rate_hz: int = 16000,
    crossfade_ms: int = 50,
) -> Tuple[bytes, bytes]:
    """
    Perform an equal-power crossfade between the held-back tail of the previous
    chunk and the start of a new chunk.

    The overlap is ``min(len(tail), len(new chunk), crossfade window)`` whole
    samples. The overlapping samples of both sides are replaced by a single
    mixed copy. Every other input byte is emitted exactly once, either now or
    via a later call's tail. The last crossfade window of the resulting stream
    is held back (not emitted) so it can be crossfaded with the next chunk. The
    caller should emit the final tail once the stream ends.

    All audio is int16 PCM mono in native byte order. Chunks need not be
    sample aligned: a dangling odd byte is carried in the tail and re-joined
    with the first byte of the next chunk.

    Args:
        previous_tail: Tail returned by the previous call, or None/b"" at the
            start of a stream.
        new_audio_bytes: The next chunk of PCM audio.
        sample_rate_hz: Sample rate used to convert ``crossfade_ms`` to samples.
        crossfade_ms: Crossfade length. Values that yield zero (or negative)
            samples disable crossfading and pass audio straight through.

    Returns:
        (to_emit, new_tail): ``to_emit`` is whole-sample bytes to stream now.
        ``new_tail`` is at most one crossfade window of audio (plus any
        dangling odd byte), to pass back as ``previous_tail`` on the next call.
    """
    prev = previous_tail or b""
    if not new_audio_bytes:
        # Nothing new to emit; just keep the existing tail
        return b"", prev

    bytes_per_sample = 2  # int16 mono
    crossfade_samples = max(0, int(sample_rate_hz * crossfade_ms / 1000))
    crossfade_bytes = crossfade_samples * bytes_per_sample

    new = new_audio_bytes
    if len(prev) % bytes_per_sample:
        # The tail ends with half a sample whose other half is the first byte
        # of this chunk: complete it so both sides are sample aligned again.
        prev, new = prev + new[:1], new[1:]

    # Crossfade as many whole samples as both sides can provide (a short tail
    # from a short first chunk, or a short new chunk, gives a partial overlap).
    overlap_samples = min(
        len(prev) // bytes_per_sample,
        len(new) // bytes_per_sample,
        crossfade_samples,
    )
    if overlap_samples > 0:
        overlap_bytes = overlap_samples * bytes_per_sample
        prev_overlap = np.frombuffer(prev[-overlap_bytes:], dtype=np.int16).astype(np.float32)
        new_overlap = np.frombuffer(new[:overlap_bytes], dtype=np.int16).astype(np.float32)
        fade_out, fade_in = _equal_power_fade_curves(overlap_samples)
        mixed = prev_overlap * fade_out + new_overlap * fade_in
        # Equal-power gains sum to up to sqrt(2) mid-fade, so correlated
        # signals can exceed the int16 range: round and saturate instead of
        # letting astype() truncate and wrap around.
        crossfaded = np.clip(np.rint(mixed), -32768, 32767).astype(np.int16).tobytes()
        stream = prev[:-overlap_bytes] + crossfaded + new[overlap_bytes:]
    else:
        stream = prev + new

    # Hold back the last crossfade window (plus a dangling odd byte, if any)
    # for the next call; everything before it is final and safe to emit.
    # crossfade_bytes is even, so the split point always lands on a sample.
    held = crossfade_bytes + len(stream) % bytes_per_sample
    split = max(0, len(stream) - held)
    return stream[:split], stream[split:]


