import dataclasses
from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
    import mlx.core as mx
    import mlx.nn as nn
    from transformers import PreTrainedTokenizer

    from outlines.generate.api import GenerationParameters, SamplingParameters
    from outlines.processors import OutlinesLogitsProcessor


class MLXLM:
    """
    Represents an `mlx_lm` model
    """

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.model = model
        self.mlx_tokenizer = tokenizer  # returns mlx tensors, used for encode()
        self.tokenizer = TransformerTokenizer(
            tokenizer._tokenizer
        )  # _tokenizer is HF Tokenizer

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> str:
        streamer = self.stream(
            prompts, generation_parameters, logits_processor, sampling_parameters
        )
        return "".join(list(streamer))

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> Iterator[str]:
        """Generate text using `mlx_lm`.

        Arguments
        ---------
        prompts
            A prompt or list of prompts.
        generation_parameters
            An instance of `GenerationParameters` that contains the prompt,
            the maximum number of tokens, stop sequences and seed. All the
            arguments to `SequenceGeneratorAdapter`'s `__cal__` method.
        logits_processor
            The logits processor to use when generating text.
        sampling_parameters
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as available
            in Outlines.
        Returns
        -------
        The generated text.
        """
        import mlx.core as mx

        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
            sampling_parameters
        )
        if max_tokens is None:
            max_tokens = int(1e9)

        if not isinstance(prompts, str):
            raise NotImplementedError(
                "The `mlx-lm` library does not support batch inference."
            )
        if sampler == "beam_search":
            raise NotImplementedError(
                "The `mlx-lm` library does not support Beam Search."
            )
        if num_samples != 1:
            raise NotImplementedError(
                "The `mlx-lm` library does not allow to take several samples."
            )
        if top_k is not None:
            raise NotImplementedError("The `mlx-lm` library does not support top_k.")
        if seed is not None:
            raise NotImplementedError("The `mlx-lm` library does not support seed.")
        if stop_at is not None:
            raise NotImplementedError("The `mlx-lm` library does not support stop_at.")

        generate_kwargs = {
            "temp": temperature,
            "top_p": top_p,
            "sampler": sampler,
            "logits_processor": logits_processor,
        }

        # Adapted from
        # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
        prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts))

        detokenizer = self.mlx_tokenizer.detokenizer
        detokenizer.reset()

        for (token, prob), n in zip(
            self.generate_step(prompt_tokens, **generate_kwargs),
            range(max_tokens),
        ):
            if token == self.tokenizer.eos_token_id:
                break
            detokenizer.add_token(token)
            yield detokenizer.last_segment

        detokenizer.finalize()
        yield detokenizer.last_segment

    def generate_step(
        self,
        prompt: "mx.array",
        temp: Optional[float],
        top_p: Optional[float],
        sampler: str,
        logits_processor: "OutlinesLogitsProcessor",
    ) -> Generator[Tuple[int, float], None, None]:
        """
        Adapted from
        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129

        A generator producing token ids based on the given prompt from the model.

            Args:
                prompt (mx.array): The input prompt.
                temp (float): The temperature for sampling, if 0 the argmax is used.
                  Default: ``0``.
                top_p (float, optional): Nulceus sampling, higher means model considers
                  more less likely words.
                sampler (str): The sampler string defined by SequenceGeneratorAdapter
                logits_processor (OutlinesLogitsProcessor): Augment logits before sampling.
        """
        import mlx.core as mx
        import mlx_lm

        temperature: float = temp or 1.0

        def sample(logits: "mx.array") -> Tuple["mx.array", float]:
            softmax_logits = mx.softmax(logits)

            if temperature == 0.0 or sampler == "greedy":
                token = mx.argmax(logits, axis=-1)
            elif sampler == "multinomial":
                if top_p is not None and top_p > 0 and top_p < 1.0:
                    token = mlx_lm.sample_utils.top_p_sampling(
                        logits, top_p, temperature
                    )
                else:
                    token = mx.random.categorical(logits * (1 / temperature))
            else:
                raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")

            prob = softmax_logits[0, token]
            return token, prob

        cache = mlx_lm.models.cache.make_prompt_cache(self.model)

        # kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model()
        unprocessed_input_ids = prompt
        generated_ids: List[int] = []

        while True:
            logits = self.model(unprocessed_input_ids[None], cache=cache)
            logits = logits[:, -1, :]

            if logits_processor is not None:
                # convert to logits_processor 1d expectation, apply, then convert back
                logits_1d = logits.reshape(-1)
                logits_1d = logits_processor(generated_ids, logits_1d)
                logits = logits_1d.reshape(1, -1)

            new_token_single, prob = sample(logits)
            new_token = new_token_single.item()
            yield new_token, prob

            generated_ids.append(new_token)
            unprocessed_input_ids = new_token_single


def mlxlm(
    model_name: str,
    tokenizer_config: dict = {},
    model_config: dict = {},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
):
    """Instantiate a model from the `mlx_lm` library and its tokenizer.

    Signature adapted from
    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422

    Parameters
    ----------
    Args:
        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config(dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``

    Returns
    -------
    A `MLXLM` model instance.

    """
    try:
        import mlx.core as mx
        import mlx_lm
    except ImportError:
        raise ImportError(
            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
        )
    if not mx.metal.is_available():
        raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal)")

    model, tokenizer = mlx_lm.load(
        model_name,
        tokenizer_config=tokenizer_config,
        model_config=model_config,
        adapter_path=adapter_path,
        lazy=lazy,
    )
    return MLXLM(model, tokenizer)
