from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Generic, TypeVar

import numpy as np
from pydantic import ConfigDict, Field

from mistral_common.audio import Audio
from mistral_common.base import MistralBase
from mistral_common.protocol.fim.request import FIMRequest
from mistral_common.protocol.instruct.chunk import UserContentChunk
from mistral_common.protocol.instruct.messages import (
    AssistantMessageType,
    UserMessage,
)
from mistral_common.protocol.instruct.request import InstructRequest
from mistral_common.protocol.instruct.tool_calls import Tool
from mistral_common.protocol.speech.request import SpeechRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import AudioEncoder
from mistral_common.tokens.tokenizers.image import ImageEncoder
from mistral_common.tokens.tokenizers.model_settings_builder import ModelSettingsBuilder


class UserMessagePosition(str, Enum):
    r"""Where to encode the available tools among the user messages.

    Attributes:
        first: Encode the available tools with the first user message.
        last: Encode the available tools with the last user message.

    Examples:
        >>> position = UserMessagePosition.last
    """

    first = "first"
    last = "last"


class SpecialTokens(str, Enum):
    r"""Enum of special tokens used in the tokenizer.

    Attributes:
        unk: The unknown token.
        bos: The beginning of string token.
        eos: The end of string token.
        begin_inst: The beginning of instruction token.
        end_inst: The end of instruction token.
        begin_tools: The beginning of tools token.
        end_tools: The end of tools token.
        begin_tool_results: The beginning of tool results token.
        end_tool_results: The end of tool results token.
        tool_calls: The tool calls token.
        img: The image token.
        pad: The pad token.
        img_break: The image break token.
        img_end: The image end token.
        prefix: The prefix token for FIM.
        middle: The middle token for FIM.
        suffix: The suffix token for FIM.
        begin_system: The beginning of system prompt token.
        end_system: The end of system prompt token.
        begin_tool_content: The beginning of tool content token.
        args: The args token.
        call_id: The call id token.
        audio: The audio token.
        begin_audio: The beginning of audio token.
        transcribe: The transcribe token.
        begin_think: The beginning of think token.
        end_think: The end of think token.
        streaming_pad: The streaming pad token.
        streaming_word: The streaming word token.
        text_to_audio: The text to audio token.
        audio_to_text: The audio to text token.
        begin_model_settings: The beginning of model settings token.
        end_model_settings: The end of model settings token.

    Examples:
        >>> unk = SpecialTokens.unk
    """

    unk = "<unk>"
    bos = "<s>"
    eos = "</s>"
    begin_inst = "[INST]"
    end_inst = "[/INST]"
    begin_tools = "[AVAILABLE_TOOLS]"
    end_tools = "[/AVAILABLE_TOOLS]"
    begin_tool_results = "[TOOL_RESULTS]"
    end_tool_results = "[/TOOL_RESULTS]"
    tool_calls = "[TOOL_CALLS]"
    img = "[IMG]"
    pad = "<pad>"
    img_break = "[IMG_BREAK]"
    img_end = "[IMG_END]"
    prefix = "[PREFIX]"
    middle = "[MIDDLE]"
    suffix = "[SUFFIX]"
    begin_system = "[SYSTEM_PROMPT]"
    end_system = "[/SYSTEM_PROMPT]"
    begin_tool_content = "[TOOL_CONTENT]"
    args = "[ARGS]"
    call_id = "[CALL_ID]"
    audio = "[AUDIO]"
    begin_audio = "[BEGIN_AUDIO]"
    transcribe = "[TRANSCRIBE]"
    begin_think = "[THINK]"
    end_think = "[/THINK]"
    streaming_pad = "[STREAMING_PAD]"
    streaming_word = "[STREAMING_WORD]"
    text_to_audio = "[NEXT_AUDIO_TEXT]"
    audio_to_text = "[REPEAT_AUDIO_TEXT]"
    begin_model_settings = "[MODEL_SETTINGS]"
    end_model_settings = "[/MODEL_SETTINGS]"


class SpecialTokenPolicy(str, Enum):
    r"""Policy describing how special tokens are handled when encoding/decoding.

    Attributes:
        IGNORE: Ignore special tokens.
        KEEP: Keep special tokens.
        RAISE: Raise an error if special tokens are found.

    Examples:
        >>> policy = SpecialTokenPolicy.KEEP
    """

    IGNORE = "ignore"
    KEEP = "keep"
    RAISE = "raise"

    @classmethod
    def _missing_(cls, value: Any) -> Any:
        # Backward compatibility: older callers passed the integers 0, 1 and 2
        # instead of the string values. Compare with `==`, in this order,
        # before deferring to the default Enum handling.
        legacy_mapping = (
            (0, SpecialTokenPolicy.IGNORE),
            (1, SpecialTokenPolicy.KEEP),
            (2, SpecialTokenPolicy.RAISE),
        )
        for legacy_value, member in legacy_mapping:
            if value == legacy_value:
                return member
        return super()._missing_(value)


class TokenizerVersion(str, Enum):
    r"""Enum of tokenizer versions.

    Allows distinguishing between different versions of the tokenizer and maintaining backward compatibility.

    Versions are ordered by the number following the ``v`` prefix, so ``TokenizerVersion.v3 < TokenizerVersion.v11``
    even though ``"v3" > "v11"`` as plain strings. A plain string operand is first converted with
    ``TokenizerVersion(...)`` (raising ``ValueError`` if it is not a known version); comparing with a non-string
    operand raises ``TypeError``.

    Attributes:
        v1: The first version of the tokenizer.
        v2: The second version of the tokenizer that includes special control tokens [INST], [/INST].
        v3: The third version of the tokenizer that includes improved function calling.
        v7: The seventh version of the tokenizer that includes improved system prompt and function calling.
        v11: The eleventh version of the tokenizer that includes improved function calling.
        v13: The thirteenth version of the tokenizer that includes no call id tokenization and better prompt caching.
        v15: The fifteenth version of the tokenizer that includes model settings.

    Examples:
        >>> version = TokenizerVersion.v1
        >>> version < "v11"
        True
    """

    @property
    def _version_num(self) -> int:
        # Numeric part of the value (e.g. "v13" -> 13); the ordering key for comparisons.
        return int(self.value[1:])

    @property
    def supports_model_settings(self) -> bool:
        r"""Whether this version supports model settings (``v15`` and later)."""
        return self >= TokenizerVersion.v15

    # Every TokenizerVersion member is also a `str`, so a single isinstance check
    # accepts both supported operand types. Anything else returns NotImplemented,
    # letting Python try the reflected operation and then raise TypeError instead
    # of silently returning None.

    def __lt__(self, other: "str | TokenizerVersion") -> bool:
        if not isinstance(other, str):
            return NotImplemented
        return self._version_num < TokenizerVersion(other)._version_num

    def __le__(self, other: "str | TokenizerVersion") -> bool:
        if not isinstance(other, str):
            return NotImplemented
        return self._version_num <= TokenizerVersion(other)._version_num

    def __gt__(self, other: "str | TokenizerVersion") -> bool:
        if not isinstance(other, str):
            return NotImplemented
        return self._version_num > TokenizerVersion(other)._version_num

    def __ge__(self, other: "str | TokenizerVersion") -> bool:
        if not isinstance(other, str):
            return NotImplemented
        return self._version_num >= TokenizerVersion(other)._version_num

    v1 = "v1"  # vocab_size = 32000
    v2 = "v2"  # vocab_size = 32768 with special control tokens [INST], [/INST]
    v3 = "v3"  # vocab_size = 32768 (spm) OR 131072 (tekken) with improved function calling
    v7 = "v7"  # vocab_size = 32768 (spm) or 131072 (tekken) with improved system prompt and function calling
    v11 = "v11"  # 131072 (tekken) with improved function calling
    v13 = "v13"  # 131072 (tekken) with no call id and better prompt caching
    v15 = "v15"  # 131072 (tekken) with model settings


class Tokenized(MistralBase):
    r"""A tokenized [`InstructRequest`][mistral_common.protocol.instruct.request.InstructRequest].

    Attributes:
        tokens: The token ids.
        text: The text representation of the tokens, if available.
        prefix_ids: The prefix ids for FIM, if any.
        images: The loaded images associated with the tokens.
        audios: The audios associated with the tokens.

    Examples:
        >>> tokenized = Tokenized(tokens=[1, 2, 3], text="Hello world", prefix_ids=[1], images=[])
    """

    # Needed because fields hold non-pydantic types such as `np.ndarray`.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    tokens: list[int]
    text: str | None = None
    prefix_ids: list[int] | None = None
    images: list[np.ndarray] = Field(default_factory=list)
    audios: list[Audio] = Field(default_factory=list)


class Tokenizer(ABC):
    r"""Abstract interface of a low-level tokenizer mapping strings to token ids and back.

    Concrete subclasses must implement every abstract property and method below before they can be instantiated.
    """

    @property
    @abstractmethod
    def n_words(self) -> int:
        r"""Vocabulary size of the tokenizer."""

    @property
    @abstractmethod
    def special_ids(self) -> set[int]:
        r"""Ids of the special tokens."""

    @property
    @abstractmethod
    def num_special_tokens(self) -> int:
        r"""The number of special tokens of the tokenizer."""

    @property
    @abstractmethod
    def model_settings_builder(self) -> ModelSettingsBuilder | None:
        r"""The model settings builder, or None if unsupported by this version."""

    @abstractmethod
    def vocab(self) -> list[str]:
        r"""All tokens in the vocabulary as strings."""

    @abstractmethod
    def id_to_piece(self, token_id: int) -> str:
        r"""Convert a token id to the token str."""

    @property
    @abstractmethod
    def bos_id(self) -> int:
        r"""id of the Beginning of String token."""

    @property
    @abstractmethod
    def eos_id(self) -> int:
        r"""id of the End of String token."""

    @property
    @abstractmethod
    def pad_id(self) -> int:
        r"""id of the Pad token."""

    @property
    @abstractmethod
    def unk_id(self) -> int:
        r"""id of the Unk token."""

    @abstractmethod
    def encode(self, s: str, bos: bool, eos: bool) -> list[int]:
        r"""Convert a string to a list of token ids.

        Args:
            s: The string to encode.
            bos: Whether to add the Beginning of String token.
            eos: Whether to add the End of String token.

        Returns:
            The token ids.
        """

    @abstractmethod
    def decode(self, tokens: list[int], special_token_policy: SpecialTokenPolicy = SpecialTokenPolicy.IGNORE) -> str:
        r"""Decode the token ids to a string.

        Args:
            tokens: The token ids to decode.
            special_token_policy: The policy to use for special tokens.

        Returns:
            The decoded string.
        """

    @abstractmethod
    def get_special_token(self, s: str) -> int:
        r"""Get the id of a control token.

        Args:
            s: The string form of the control token, e.g. ``"[INST]"``.

        Returns:
            The id of the control token.
        """

    @abstractmethod
    def is_special(self, token: int | np.integer | str) -> bool:
        r"""Check if token id or token str is a special token.

        Args:
            token: A token id (Python or NumPy integer) or a token string.

        Returns:
            True if the token is a special token.
        """

    @property
    @abstractmethod
    def version(self) -> TokenizerVersion:
        r"""Get the version of the tokenizer."""

    @abstractmethod
    def _to_string(self, tokens: list[int]) -> str:
        r"""Convert token ids to their string form; the exact rendering is implementation-defined."""

    @property
    @abstractmethod
    def file_path(self) -> Path:
        r"""The file path of the tokenizer."""


# Type parameters of `InstructTokenizer`: concrete instruct tokenizers bind them to
# the request and tokenized-output types they accept and return.
TokenizedType = TypeVar("TokenizedType", bound=Tokenized)
FIMRequestType = TypeVar("FIMRequestType", bound=FIMRequest)
InstructRequestType = TypeVar("InstructRequestType", bound=InstructRequest)


class InstructTokenizer(Generic[InstructRequestType, FIMRequestType, TokenizedType, AssistantMessageType]):
    r"""Base class for instruct tokenizers.

    Subclasses encode instruct, FIM, transcription and speech requests into tokenized objects and decode token
    ids back to text.

    Note:
        This class derives from ``Generic`` and not from ``ABC``, so the ``@abstractmethod`` markers below are not
        enforced by this class: they document the interface subclasses are expected to implement.

    Attributes:
        tokenizer: The tokenizer to use.
        image_encoder: The image encoder to use if any.
        audio_encoder: The audio encoder to use if any.
    """

    tokenizer: Tokenizer
    image_encoder: ImageEncoder | None
    audio_encoder: AudioEncoder | None

    @property
    def version(self) -> TokenizerVersion:
        r"""The version of the tokenizer."""
        return self.tokenizer.version

    def __init__(
        self, tokenizer: Tokenizer, image_encoder: ImageEncoder | None, audio_encoder: AudioEncoder | None
    ) -> None:
        r"""Declare the constructor signature shared by instruct tokenizers.

        The base implementation does not assign ``tokenizer``, ``image_encoder`` or ``audio_encoder``;
        subclasses must set these attributes themselves.

        Args:
            tokenizer: The tokenizer to use.
            image_encoder: The image encoder to use if any.
            audio_encoder: The audio encoder to use if any.
        """

    @abstractmethod
    def encode_instruct(self, request: InstructRequestType) -> TokenizedType:
        r"""Instruct request to Tokenized object.

        Args:
            request: The instruct request to encode.

        Returns:
            The tokenized instruct request.
        """

    @abstractmethod
    def encode_transcription(self, request: TranscriptionRequest) -> TokenizedType:
        r"""Encode an audio transcription request into a tokenized format.

        This method processes a transcription request containing audio data,
        encodes the user message, and returns the tokenized output.

        Args:
            request: The transcription request object containing
                the audio data to be encoded.

        Returns:
            Tokenized: The tokenized representation of the audio data, including processed audio and tokens.
        """

    @abstractmethod
    def encode_speech_request(self, request: SpeechRequest) -> TokenizedType:
        r"""Encode a speech synthesis request into a tokenized format.

        This method processes a speech request containing text input and
        optional reference audio or voice preset, and returns the tokenized output.

        Args:
            request: The speech request object containing the text and voice/audio data.

        Returns:
            Tokenized: The tokenized representation of the speech request.
        """

    @abstractmethod
    def decode(self, tokens: list[int], special_token_policy: SpecialTokenPolicy) -> str:
        r"""Convert token ids to a string.

        Args:
            tokens: The token ids to decode.
            special_token_policy: The policy to use for special tokens.

        Returns:
            The decoded string.
        """

    @abstractmethod
    def encode_fim(self, request: FIMRequestType) -> TokenizedType:
        r"""FIM request to Tokenized object.

        Args:
            request: The FIM request to encode.

        Returns:
            The tokenized FIM request.
        """

    @abstractmethod
    def encode_user_message(
        self,
        message: UserMessage,
        available_tools: list[Tool] | None,
        is_last: bool,
        is_first: bool,
        system_prompt: str | None = None,
        force_img_first: bool = False,
    ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
        r"""Encode a user message.

        Args:
            message: The user message to encode.
            available_tools: The available tools, if any.
            is_last: Whether the message is the last one.
            is_first: Whether the message is the first one.
            system_prompt: The system prompt.
            force_img_first: Whether to force the image to be first.

        Returns:
            A tuple of the encoded token ids, the images and the audios.
        """
        ...

    @abstractmethod
    def encode_user_content(
        self,
        content: str | list[UserContentChunk],
        is_last: bool,
        system_prompt: str | None = None,
        force_img_first: bool = False,
    ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
        r"""Encode the content of a user message.

        Args:
            content: The user content to encode, either a plain string or a list of content chunks.
            is_last: Whether the content is the last one.
            system_prompt: The system prompt.
            force_img_first: Whether to force the image to be first.

        Returns:
            A tuple of the encoded token ids, the images and the audios.
        """
        ...

    @abstractmethod
    def _to_string(self, tokens: list[int]) -> str:
        r"""Convert token ids to their string form; the exact rendering is implementation-defined."""
