import logging
import math
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Any

import numpy as np

from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import AudioChunk, AudioURLChunk, AudioURLType

logger = logging.getLogger(__name__)

# For offline streaming we encode the whole audio at once.
# Because the model's output is delayed by <transcription_delay_ms> + the length
# of the word being transcribed, we must always append a buffer of the maximum
# word length to the end of the audio.
# For now we assume that no word requires more than 10 tokens (0.8s, i.e. 80 ms
# per token), which might not hold in practice but won't affect any eval results.
OFFLINE_STREAMING_BUFFER_TOKENS = 10


def _check_mult_of(num_samples: int, mult_of: int) -> None:
    assert num_samples % mult_of == 0, f"{num_samples=} must be a multiple of {mult_of=}"


class TranscriptionFormat(str, Enum):
    r"""Transcription format.

    Should be set by the tokenizer for correct encoding. Because this is a ``str``
    enum, members compare equal to their plain string values (e.g. ``"streaming"``).

    Attributes:
        INSTRUCT: The instruct (non-streaming) format; the default of ``AudioConfig``.
        STREAMING: The streaming format; ``AudioConfig`` then requires its streaming
            fields (delay, look-ahead/back, left-pad tokens) to be set.
    """

    INSTRUCT = "instruct"
    STREAMING = "streaming"


@dataclass
class AudioSpectrogramConfig:
    r"""Parameters of the log-mel spectrogram computed from raw audio.

    All three values are validated to be strictly positive on construction.

    Attributes:
        num_mel_bins: Number of mel bins, typically 80 or 128.
        hop_length: Number of samples between the starts of consecutive STFT
            windows (the waveform's downsampling factor), typically 160.
        window_size: Window size of the Fourier transform in samples, typically 400.
    """

    num_mel_bins: int  # size of the mel filterbank
    hop_length: int  # STFT stride, in samples
    window_size: int  # STFT window length, in samples

    def __post_init__(self) -> None:
        # Check each dimension in declaration order; the assertion message is the offending value.
        for attr_name in ("num_mel_bins", "hop_length", "window_size"):
            attr_value = getattr(self, attr_name)
            assert attr_value > 0, attr_value


@dataclass
class AudioConfig:
    r"""Configuration for audio processing.

    Attributes:
        sampling_rate: Sampling rate of the audio, in Hz.
        frame_rate: Number of frames per second accepted by the tokenizer model.
        encoding_config: Configuration for audio spectrogram.
        chunk_length_s: If set, audio is padded to a multiple of ``chunk_length_s``
            seconds (optional; must be None in streaming mode).
        transcription_format: Whether audio is encoded in the instruct (non-streaming)
            or the streaming format.
        transcription_delay_ms: Delay between the audio stream and the text stream, in
            milliseconds. Required in streaming mode, where it must be > 0 and a multiple
            of ``frame_duration_ms``; must be None otherwise.
        streaming_look_ahead_ms: Streaming look-ahead in milliseconds. Required in
            streaming mode, must be None otherwise.
        streaming_look_back_ms: Streaming look-back in milliseconds. Required in
            streaming mode, must be None otherwise.
        streaming_n_left_pad_tokens: Number of tokens' worth of silence padded on the
            left of the audio in streaming mode. Required in streaming mode, must be None
            otherwise.
        voice_num_audio_tokens: Mapping from speaker voice name to number of audio tokens
            for that speaker's reference audio (optional, only for TTS).
    """

    sampling_rate: int
    # number of frames per second accepted by the tokenizer model.
    frame_rate: float
    encoding_config: AudioSpectrogramConfig
    # If set, pad audio to a multiple of chunk_length_s seconds
    chunk_length_s: float | None = None

    # If we're in streaming or non-streaming
    transcription_format: TranscriptionFormat = TranscriptionFormat.INSTRUCT

    # delay between the audio stream and text stream
    transcription_delay_ms: float | None = None

    # only relevant for streaming
    streaming_look_ahead_ms: float | None = None
    streaming_look_back_ms: float | None = None

    streaming_n_left_pad_tokens: int | None = None

    # only relevant for TTS (speech requests using a preset voice)
    voice_num_audio_tokens: dict[str, int] | None = None

    def __post_init__(self) -> None:
        assert self.frame_rate > 0, self.frame_rate
        assert self.sampling_rate > 0, self.sampling_rate

        if self.chunk_length_s is not None:
            assert self.chunk_length_s > 0, self.chunk_length_s
            # `chunk_frames` truncates to an int, so a very short chunk can round down to 0 samples.
            assert self.chunk_frames > 0, (
                f"chunk_length_s and sampling_rate must both be > 0, got {self.chunk_length_s} and {self.sampling_rate}"
            )

        if not self.is_streaming:
            # make sure streaming params are only set for streaming use case
            assert self.transcription_delay_ms is None, f"{self.transcription_delay_ms=} must be None."
            assert self.streaming_look_ahead_ms is None, f"{self.streaming_look_ahead_ms=} must be None."
            assert self.streaming_look_back_ms is None, f"{self.streaming_look_back_ms=} must be None."
            assert self.streaming_n_left_pad_tokens is None, f"{self.streaming_n_left_pad_tokens=} must be None."
        else:
            assert self.transcription_delay_ms is not None, f"{self.transcription_delay_ms=} must be set."
            assert self.streaming_look_ahead_ms is not None, f"{self.streaming_look_ahead_ms=} must be set."
            assert self.streaming_look_back_ms is not None, f"{self.streaming_look_back_ms=} must be set."
            assert self.streaming_n_left_pad_tokens is not None, f"{self.streaming_n_left_pad_tokens=} must be set."

            frame_duration_ms = self.frame_duration_ms

            assert self.transcription_delay_ms > 0, f"{self.transcription_delay_ms=} must be > 0"
            # The delay must correspond to a whole number of audio frames/tokens.
            assert self.transcription_delay_ms % frame_duration_ms == 0, (
                f"{self.transcription_delay_ms=} must be a multiple of {frame_duration_ms=}"
            )
            assert self.chunk_length_s is None, f"{self.chunk_length_s=} cannot be set in streaming."

    @property
    def is_streaming(self) -> bool:
        r"""Whether ``transcription_format`` is ``TranscriptionFormat.STREAMING``."""
        return self.transcription_format == TranscriptionFormat.STREAMING

    def num_audio_tokens(self, audio_len: int) -> int:
        r"""Number of audio tokens produced for ``audio_len`` raw samples.

        The sample count is first converted to spectrogram frames (divided by
        ``hop_length``), then to tokens (ceil-divided by ``audio_length_per_tok``).
        """
        # For non-negative `audio_len` both branches compute `audio_len // hop_length`:
        # ceil(x / h - 1) == floor(x / h) whenever x is not a multiple of h.
        if audio_len % self.encoding_config.hop_length != 0:
            audio_len = math.ceil(audio_len / self.encoding_config.hop_length - 1)
        else:
            audio_len = audio_len // self.encoding_config.hop_length

        return math.ceil(audio_len / self.audio_length_per_tok)

    @property
    def num_delay_tokens(self) -> int:
        r"""Deprecated alias of :meth:`get_num_delay_tokens` using the configured delay."""
        # TODO(Patrick) - delete in 1.11.0
        # only used in vLLM in voxtral_realtime.py
        # stacklevel=2 attributes the warning to the caller rather than to this module.
        warnings.warn("Use get_num_delay_tokens instead of num_delay_tokens", DeprecationWarning, stacklevel=2)
        return self.get_num_delay_tokens()

    def get_num_delay_tokens(self, transcription_delay_ms: float | None = None) -> int:
        r"""Number of audio tokens covered by the transcription delay (streaming only).

        Args:
            transcription_delay_ms: Delay override in milliseconds; defaults to the
                configured ``transcription_delay_ms``.
        """
        assert self.is_streaming, f"Can't call get_num_delay_tokens if {self.is_streaming=}."
        if transcription_delay_ms is None:
            transcription_delay_ms = self.transcription_delay_ms
        # streaming pad tokens
        assert transcription_delay_ms is not None, f"Can't call get_num_delay_tokens if {transcription_delay_ms=}."
        return self.num_audio_tokens(self.delay_len(transcription_delay_ms))

    def delay_len(self, transcription_delay_ms: float) -> int:
        r"""Convert a delay in milliseconds into a (truncated) number of audio samples."""
        # Multiply before dividing: `ms * sampling_rate` is exact for integral inputs, so a
        # delay that maps to a whole number of samples yields exactly that number. Dividing
        # first can land just below the integer and get truncated down by one
        # (e.g. int(290 / 1000.0 * 100) == 28, whereas int(290 * 100 / 1000.0) == 29).
        return int(transcription_delay_ms * self.sampling_rate / 1000.0)

    @property
    def frame_duration_ms(self) -> float:
        r"""Duration of one frame (audio token) in milliseconds."""
        return 1000.0 / self.frame_rate

    @property
    def chunk_frames(self) -> int:
        r"""Calculate the number of samples per chunk (``chunk_length_s * sampling_rate``, truncated)."""
        assert self.chunk_length_s is not None, f"Can't call chunk_frames if {self.chunk_length_s=}."
        return int(self.chunk_length_s * self.sampling_rate)

    @property
    def raw_audio_length_per_tok(self) -> int:
        r"""Number of raw audio samples per audio token (``sampling_rate // frame_rate``)."""
        return int(self.sampling_rate // self.frame_rate)

    @property
    def audio_length_per_tok(self) -> int:
        r"""Calculate the number of spectrogram frames per audio token."""
        downsample_factor = float(self.raw_audio_length_per_tok)
        downsample_factor /= self.encoding_config.hop_length
        return int(downsample_factor)

    def n_right_pad_tokens(self, transcription_delay_ms: float | None = None) -> int:
        r"""Number of tokens' worth of silence to append on the right (streaming only)."""
        assert self.is_streaming, f"Can't call n_right_pad_tokens if {self.is_streaming=}."
        # we need to pad on the right to ensure the models transcribes
        # - the induced delay on the prefill step (num_delay_tokens)
        # - the BOS token (1)
        # - a heuristic that defines a max token length for a single word
        #   (OFFLINE_STREAMING_BUFFER_TOKENS)
        return (self.get_num_delay_tokens(transcription_delay_ms) + 1) + OFFLINE_STREAMING_BUFFER_TOKENS

    @property
    def n_left_pad_tokens(self) -> int:
        r"""Number of tokens' worth of silence to prepend on the left (streaming only)."""
        assert self.is_streaming, f"Can't call n_left_pad_tokens if {self.is_streaming=}."
        # We also pad on the left as this has shown to improve performance
        # simply by giving the model "more compute". The amount is taken verbatim
        # from `streaming_n_left_pad_tokens`; no delay or buffer tokens are added here.
        assert self.streaming_n_left_pad_tokens is not None, f"{self.streaming_n_left_pad_tokens=} must be set."
        return self.streaming_n_left_pad_tokens


@dataclass
class AudioEncoding:
    r"""Encapsulates the tokens and audio data for an audio chunk.

    Attributes:
        tokens: Token ids for this audio chunk (e.g. a begin-audio token followed by
            audio placeholder tokens, or streaming pad tokens in streaming mode).
        audio: Audio waveform data (as resampled/padded by the encoder), or None when
            using a preset voice (no reference audio to forward to the model).
    """

    tokens: list[int]
    audio: Audio | None


@dataclass
class SpecialAudioIDs:
    r"""Special text tokens corresponding to audio token sequence.

    Each id may be ``None`` (presumably when the tokenizer has no such token);
    ``AudioEncoder`` asserts that an id is set when it is accessed.

    Attributes:
        audio: Token representing audio.
        begin_audio: Token representing the beginning of audio.
        streaming_pad: Token representing streaming pad of audio. Only relevant for streaming models.
        text_to_audio: Token representing intent to convert text to audio.
        audio_to_text: Token representing intent to convert audio to text.
    """

    audio: int | None
    begin_audio: int | None
    streaming_pad: int | None
    text_to_audio: int | None
    audio_to_text: int | None


class AudioEncoder:
    r"""Encodes audio chunks into a format suitable for further processing.

    Attributes:
        audio_config: Configuration for audio processing.
        encoding_config: Configuration for audio spectrogram (``audio_config.encoding_config``).
        special_ids: Special tokens for audio encoding.
    """

    def __init__(self, audio_config: AudioConfig, special_ids: SpecialAudioIDs) -> None:
        self.audio_config = audio_config
        self.encoding_config = audio_config.encoding_config
        self.special_ids = special_ids

    def pad(
        self,
        audio_array: np.ndarray,
        sampling_rate: int,
        transcription_delay_ms: float | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        r"""Pad the audio array to the desired length.

        Exactly one strategy is applied, in this order of precedence:

        - ``chunk_length_s`` set: right-pad with zeros to the next multiple of ``chunk_frames``.
        - streaming format: zero-pad both left and right (see ``_get_streaming_pad``).
        - otherwise: right-pad with zeros so the audio spans at least one spectrogram window.

        Args:
            audio_array: Audio data as a numpy array. Its length is read from the last
                axis, but ``np.pad`` is given a single (before, after) pair, which numpy
                applies to every axis — so this assumes a 1-D waveform.
            sampling_rate: Sampling rate of the audio (only checked in the chunked branch).
            transcription_delay_ms (optional): Delay in milliseconds for transcription
                (streaming only).
            **kwargs: Ignored; accepted only for backward compatibility.

        Returns:
            Padded audio array.
        """
        # TODO(Patrick) - remove **kwargs as it's just there to swallow deprecated
        # keyword args from voxtral_realtime in vLLM. It was
        # relevant for the release. Remove in mistral_common version 1.11
        num_samples = audio_array.shape[-1]
        if self.audio_config.chunk_length_s:
            target_len = self.next_multiple_of_chunk_frames(num_samples, sampling_rate)
            audio_array = np.pad(audio_array, (0, target_len - num_samples))
        elif self.audio_config.is_streaming:
            left_pad, right_pad = self._get_streaming_pad(num_samples, transcription_delay_ms)
            # we pad both left & right as this leads to better performance
            audio_array = np.pad(audio_array, (left_pad, right_pad))
        elif isinstance(self.encoding_config, AudioSpectrogramConfig) and num_samples < self.encoding_config.window_size:
            # minimum length for audios is at least one spectrogram frame
            audio_array = np.pad(audio_array, (0, self.encoding_config.window_size - num_samples))

        return audio_array

    def get_padding_audio(self, transcription_delay_ms: float | None = None) -> tuple[Audio, Audio]:
        r"""Gets left and right padding for realtime audio models.

        The padding is silence (zeros, float32) at the configured sampling rate, sized as
        ``_get_streaming_pad`` would pad an empty audio.

        Args:
            transcription_delay_ms (optional): Delay in milliseconds for transcription.

        Returns:
            Tuple of left and right padding for realtime audio models.
        """
        left_pad, right_pad = self._get_streaming_pad(0, transcription_delay_ms)
        left_pad_audio = Audio(
            audio_array=np.zeros(left_pad, dtype=np.float32),
            sampling_rate=self.audio_config.sampling_rate,
            format="wav",
        )
        right_pad_audio = Audio(
            audio_array=np.zeros(right_pad, dtype=np.float32),
            sampling_rate=self.audio_config.sampling_rate,
            format="wav",
        )
        return left_pad_audio, right_pad_audio

    def _get_streaming_pad(self, num_samples: int, transcription_delay_ms: float | None = None) -> tuple[int, int]:
        r"""Compute ``(left_pad, right_pad)`` sample counts for streaming audio.

        The right pad first rounds ``num_samples`` up to a whole number of tokens, then
        appends ``n_right_pad_tokens`` tokens' worth of samples. The left pad is
        ``n_left_pad_tokens`` tokens' worth of samples.
        """
        # let's make sure the audio is a multiple of one "frame" token
        mult_of = self.audio_config.raw_audio_length_per_tok

        right_pad = int((mult_of - (num_samples % mult_of)) % mult_of)

        _extra_right_pad_tokens = self.audio_config.n_right_pad_tokens(transcription_delay_ms)
        _extra_right_pad_samples = int(mult_of * _extra_right_pad_tokens)
        _check_mult_of(_extra_right_pad_samples, mult_of)
        right_pad += _extra_right_pad_samples

        # We also pad on the left as this has shown to improve performance
        # simply by giving the model "more compute". The number of tokens comes
        # straight from `streaming_n_left_pad_tokens` in the audio config.
        _extra_left_pad_tokens = self.audio_config.n_left_pad_tokens
        left_pad = int(mult_of * _extra_left_pad_tokens)
        _check_mult_of(left_pad, mult_of)

        return left_pad, right_pad

    def next_multiple_of_chunk_frames(self, audio_array_len: int, sampling_rate: int) -> int:
        r"""Calculate the next multiple of chunk frames.

        Args:
            audio_array_len: Length of the audio array.
            sampling_rate: Sampling rate of the audio; must equal the configured rate.

        Returns:
            The smallest multiple of ``chunk_frames`` that is >= ``audio_array_len``.
        """
        assert sampling_rate == self.audio_config.sampling_rate, (
            f"Expected {sampling_rate=} to be {self.audio_config.sampling_rate=}"
        )
        assert self.audio_config.chunk_length_s is not None, (
            f"Can't call next_multiple_of_chunk_frames if {self.audio_config.chunk_length_s=}."
        )

        return math.ceil(audio_array_len / self.audio_config.chunk_frames) * self.audio_config.chunk_frames

    def encode_streaming_tokens(self, transcription_delay_ms: float | None = None) -> list[int]:
        r"""Encode the streaming tokens given a transcription delay.

        Returns ``n_left_pad_tokens + get_num_delay_tokens(...)`` streaming pad tokens.
        """
        assert isinstance(self.audio_config.encoding_config, AudioSpectrogramConfig), (
            f"Audio encoder must be spectrogram encoder, got {self.audio_config.encoding_config=}"
        )
        assert self.audio_config.transcription_delay_ms is not None

        # streaming pad tokens consist of silence we pad on left + delay tokens
        stream_pad_prefix_len = self.audio_config.n_left_pad_tokens + self.audio_config.get_num_delay_tokens(
            transcription_delay_ms
        )
        return [self.streaming_pad] * stream_pad_prefix_len

    def _encode_audio_tokens(self, signal_length: int) -> list[int]:
        r"""Build ``[BEGIN_AUDIO] + [AUDIO] * n`` for a waveform of ``signal_length`` samples."""
        # For spectrogram-based models the waveform is downsampled by hop_length when computing
        # the log-mel; the shared conversion lives in AudioConfig.num_audio_tokens.
        num_audio_tokens = self.audio_config.num_audio_tokens(signal_length)
        return [self.begin_audio_token] + [self.audio_token] * num_audio_tokens

    def encode_audio(self, audio: Audio, transcription_delay_ms: float | None = None) -> AudioEncoding:
        r"""Encode an audio optionally with transcription delay.

        Note: ``audio`` is modified in place. ``Audio.resample`` is called for its side
        effect (its return value is unused), and ``audio.audio_array`` is replaced by the
        padded array. The same object is returned inside the ``AudioEncoding``.
        """
        audio.resample(self.audio_config.sampling_rate)
        audio.audio_array = self.pad(audio.audio_array, self.audio_config.sampling_rate, transcription_delay_ms)

        if self.audio_config.is_streaming:
            tokens = self.encode_streaming_tokens(transcription_delay_ms)
        else:
            tokens = self._encode_audio_tokens(audio.audio_array.shape[0])

        return AudioEncoding(
            tokens=tokens,
            audio=audio,
        )

    def _encode_audio_tokens_for_speech_request(self, num_audio_tokens: int) -> list[int]:
        r"""Build the token sequence for a speech request's audio segment.

        Args:
            num_audio_tokens: Number of audio placeholder tokens to emit.

        Returns:
            List of token IDs: [BEGIN_AUDIO, AUDIO * num_audio_tokens].
        """
        return [self.begin_audio_token] + [self.audio_token] * num_audio_tokens

    def _get_num_audio_token_for_speech_request(self, audio_length: int) -> int:
        r"""Compute the number of audio tokens needed for a given audio length.

        Args:
            audio_length: Number of audio samples.

        Returns:
            Number of audio tokens (includes +1 for END_OUTPUT_AUDIO).
        """
        duration_s = audio_length / self.audio_config.sampling_rate
        # +1 for eoa (END_OUTPUT_AUDIO)
        return math.ceil(duration_s * self.audio_config.frame_rate) + 1

    def encode_audio_for_speech_request(self, audio: Audio | None, voice: str | None) -> AudioEncoding:
        r"""Encode audio or voice preset into an AudioEncoding for speech synthesis.

        Either ``audio`` (reference audio for voice cloning) or ``voice`` (preset name)
        must be provided. When ``audio`` is given it takes precedence and is resampled
        to the configured rate (``Audio.resample`` is called for its side effect).

        Args:
            audio: Reference audio waveform, or None to use a voice preset.
            voice: Preset voice name (e.g. 'Neutral Male', 'Neutral Female'), or None when using ref audio.

        Returns:
            AudioEncoding containing the token sequence and optional audio data.
        """
        assert audio is not None or voice is not None, (
            f"Either audio or voice must be defined to encode audio, got {audio=} and {voice=}"
        )

        if audio is not None:
            audio.resample(self.audio_config.sampling_rate)
            num_audio_tokens = self._get_num_audio_token_for_speech_request(len(audio.audio_array))
        else:
            assert self.audio_config.voice_num_audio_tokens is not None, (
                "voice_num_audio_tokens must be set in audio config to use voice-based speech requests"
            )
            assert voice is not None and voice in self.audio_config.voice_num_audio_tokens, (
                f"Unknown voice {voice!r}, expected one of {list(self.audio_config.voice_num_audio_tokens)}"
            )
            num_audio_tokens = self.audio_config.voice_num_audio_tokens[voice]

        return AudioEncoding(
            tokens=self._encode_audio_tokens_for_speech_request(num_audio_tokens),
            audio=audio,
        )

    def _encode_audio_chunk(self, content: AudioChunk) -> AudioEncoding:
        r"""Decode the raw audio carried by an ``AudioChunk`` and encode it."""
        audio = Audio.from_raw_audio(content.input_audio)
        return self.encode_audio(audio)

    def _encode_audio_url_chunk(self, content: AudioURLChunk) -> AudioEncoding:
        r"""Load audio from a file path / file URI, a URL, or base64 data, then encode it."""
        url_type = content.get_url_type()

        if url_type in {AudioURLType.file, AudioURLType.file_uri}:
            audio = Audio.from_file(content.url)
        elif url_type == AudioURLType.url:
            audio = Audio.from_url(content.url)
        else:
            # any other url type is treated as base64-encoded audio
            audio = Audio.from_base64(content.url)

        return self.encode_audio(audio)

    def __call__(self, content: AudioChunk | AudioURLChunk) -> AudioEncoding:
        r"""Call the encoder on an audio chunk or URL chunk.

        Args:
            content: Audio or URL chunk to encode.

        Returns:
            Encoded audio data and tokens.

        Raises:
            ValueError: If ``content`` is neither an ``AudioURLChunk`` nor an ``AudioChunk``.
        """
        if isinstance(content, AudioURLChunk):
            return self._encode_audio_url_chunk(content)
        elif isinstance(content, AudioChunk):
            return self._encode_audio_chunk(content)
        else:
            raise ValueError(f"Unsupported content type: {type(content)}")

    @property
    def audio_token(self) -> int:
        r"""Get the audio token."""
        assert self.special_ids.audio is not None, f"{self.special_ids.audio=} must be set."
        return self.special_ids.audio

    @property
    def begin_audio_token(self) -> int:
        r"""Get the begin audio token."""
        assert self.special_ids.begin_audio is not None, f"{self.special_ids.begin_audio=} must be set."
        return self.special_ids.begin_audio

    @property
    def streaming_pad(self) -> int:
        r"""Get the streaming pad token."""
        assert self.special_ids.streaming_pad is not None, f"{self.special_ids.streaming_pad=} must be set."
        return self.special_ids.streaming_pad

    @property
    def text_to_audio_token(self) -> int:
        r"""Get the text_to_audio token."""
        assert self.special_ids.text_to_audio is not None, f"{self.special_ids.text_to_audio=} must be set."
        return self.special_ids.text_to_audio

    @property
    def audio_to_text_token(self) -> int:
        r"""Get the audio_to_text token."""
        assert self.special_ids.audio_to_text is not None, f"{self.special_ids.audio_to_text=} must be set."
        return self.special_ids.audio_to_text
