import dataclasses
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union

import torch
from typing_extensions import Unpack

from outlines.generate.api import GenerationParameters, SamplingParameters

if TYPE_CHECKING:
    import torch.LongTensor
    from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler


# Extra keyword arguments understood by `ExLlamaV2Model.generate` and
# `ExLlamaV2Model.stream`. Declared with `total=False`, so every key is optional.
ExllamaV2Params = TypedDict(
    "ExllamaV2Params",
    {
        "max_tokens": int,
        # Stop token ids and/or stop strings.
        "stop_conditions": Optional[List[Union[int, str]]],
        "seed": Optional[int],
        # Forward reference: exllamav2 is only imported for type checking.
        "gen_settings": "ExLlamaV2Sampler.Settings",
        # One token budget per (possibly tiled) prompt.
        "max_new_tokens": List[int],
    },
    total=False,
)


class OutlinesExLlamaV2Tokenizer:
    """Adapt an exllamav2 tokenizer to the tokenizer interface used by outlines.

    Attributes
    ----------
    exl2_tokenizer
        The wrapped exllamav2 tokenizer.
    vocabulary
        Mapping returned by the wrapped tokenizer's `get_piece_to_id_dict()`.
    special_tokens
        Set of the keys of the wrapped tokenizer's `extended_piece_to_id`.
    eos_token_id
        The wrapped tokenizer's `eos_token_id`.
    """

    def __init__(self, tokenizer):
        self.exl2_tokenizer = tokenizer
        self.vocabulary = self.exl2_tokenizer.get_piece_to_id_dict()
        self.special_tokens = set(self.exl2_tokenizer.extended_piece_to_id)
        self.eos_token_id = self.exl2_tokenizer.eos_token_id

    def convert_token_to_string(self, token):
        """Return `token` unchanged; vocabulary pieces are already strings."""
        return token

    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
        """Decode token ids to text, always returning a list of strings.

        Parameters
        ----------
        token_ids
            A tensor of token ids, or anything `torch.as_tensor` accepts
            (e.g. a list of ints).

        Returns
        -------
        The wrapped tokenizer's output. A single decoded string is wrapped
        in a one-element list; a list result is returned as-is. The wrapped
        tokenizer is called with `decode_special_tokens=False`.
        """
        # `torch.as_tensor` reuses an existing tensor instead of copying it,
        # which also avoids PyTorch's "copy construct from a tensor" warning
        # that `torch.tensor(tensor)` emits on every call.
        decoded = self.exl2_tokenizer.decode(
            torch.as_tensor(token_ids),
            decode_special_tokens=False,
        )
        if isinstance(decoded, str):
            return [decoded]
        return decoded


class ExLlamaV2Model:
    """Represents an `exl2` model.

    Wraps an `ExLlamaV2DynamicGenerator` so it can be driven through
    outlines' `generate` / `stream` interface.
    """

    def __init__(
        self,
        generator: "ExLlamaV2DynamicGenerator",
        tokenizer: "OutlinesExLlamaV2Tokenizer",
        max_seq_len: int,
    ):
        """
        Parameters
        ----------
        generator
            The exllamav2 dynamic generator that runs the model.
        tokenizer
            Outlines-compatible wrapper around the exllamav2 tokenizer.
        max_seq_len
            Maximum total sequence length. When no `max_tokens` is given, each
            prompt's token budget is `max_seq_len` minus the prompt's length.
        """
        self.generator = generator
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def prepare_generation_parameters(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        sampling_parameters: SamplingParameters,
        structure_logits_processor,
        **exllamav2_params: Unpack[ExllamaV2Params],
    ) -> Tuple[ExllamaV2Params, Union[str, List[str]]]:
        """Prepare the generation parameters.

        `exllamav2` uses different default values and argument shapes than
        outlines, so the outlines parameters are translated here.

        Returns
        -------
        A tuple `(exllamav2_params, prompts)`.

        `exllamav2_params` contains:
        - `max_new_tokens`: one token budget per returned prompt;
        - `stop_conditions`: the EOS token id followed by any `stop_at` strings;
        - `seed`;
        - `gen_settings`: an `ExLlamaV2Sampler.Settings` holding the sampling
          parameters and the structure logits processor.

        `prompts` is the prompt list, tiled `num_samples` times when
        `num_samples > 1`. It is collapsed to a plain string when exactly one
        prompt remains.
        """
        from exllamav2.generator import ExLlamaV2Sampler

        if isinstance(prompts, str):
            prompts = [prompts]
        max_tokens, _, seed = dataclasses.astuple(generation_parameters)

        if max_tokens is None:
            # No explicit budget: let each prompt use the rest of the context.
            max_new_tokens = []
            for prompt in prompts:
                ids = self.generator.tokenizer.encode(
                    prompt, encode_special_tokens=True
                )
                max_new_tokens.append(self.max_seq_len - ids.shape[-1])
        else:
            max_new_tokens = [max_tokens] * len(prompts)
        exllamav2_params["max_new_tokens"] = max_new_tokens

        # The EOS token always stops generation; user stop strings are added.
        stop_conditions: List[Union[int, str]] = [
            self.generator.tokenizer.eos_token_id
        ]
        stop_at = generation_parameters.stop_at
        if isinstance(stop_at, str):
            stop_conditions.append(stop_at)
        elif isinstance(stop_at, list):
            stop_conditions.extend(stop_at)
        exllamav2_params["stop_conditions"] = stop_conditions
        exllamav2_params["seed"] = seed

        gen_settings = ExLlamaV2Sampler.Settings()
        if sampling_parameters.temperature is not None:
            gen_settings.temperature = sampling_parameters.temperature
        if sampling_parameters.top_p is not None:
            gen_settings.top_p = sampling_parameters.top_p
        if sampling_parameters.top_k is not None:
            gen_settings.top_k = sampling_parameters.top_k
        gen_settings.logits_processor = structure_logits_processor
        exllamav2_params["gen_settings"] = gen_settings

        num_samples = sampling_parameters.num_samples
        if num_samples > 1:
            # exllamav2 has no notion of several samples per prompt, so the
            # whole prompt list is tiled: [p0, ..., pB-1, p0, ..., pB-1, ...].
            # `reformat_output` relies on this exact layout.
            prompts = prompts * num_samples
            exllamav2_params["max_new_tokens"] = (
                exllamav2_params["max_new_tokens"] * num_samples
            )

        if len(prompts) == 1:
            prompts = prompts[0]

        return exllamav2_params, prompts

    def reformat_output(
        self, output: Union[str, List[str]], sampling_parameters: SamplingParameters
    ):
        """Convert exllamav2's output layout to the layout outlines expects.

        exllamav2 accepts a string or a list of prompts. (It also supports CFG
        sampling with tuples, which is ignored here.)
        1. A string prompt yields a string. This matches outlines.
        2. A list prompt yields a list, even for a single element. Outlines
           expects a bare string in that case.
        3. There is no `num_samples`. `prepare_generation_parameters` therefore
           tiles the prompt list `num_samples` times, and the flat result is
           regrouped here.

        Returns
        -------
        - a string when there is a single completion;
        - a flat list when only one of batch size / `num_samples` exceeds 1;
        - otherwise a list of lists of shape `(batch_size, num_samples)`, where
          `result[b][s]` is sample `s` of prompt `b`. This is the same layout
          the other outlines backends use.
        """
        if isinstance(output, str):
            return output
        if len(output) == 1:
            return output[0]
        num_samples = sampling_parameters.num_samples
        if num_samples > 1:
            if len(output) == num_samples:
                return output
            assert (
                len(output) % num_samples == 0
            ), "output length must be a multiple of num_samples"
            batch_size = len(output) // num_samples
            # Tiled layout: output[s * batch_size + b] is sample s of prompt b,
            # so prompt b's samples are every batch_size-th element from b.
            return [list(output[b::batch_size]) for b in range(batch_size)]
        return output

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        structure_logits_processor,
        sampling_parameters: SamplingParameters,
        **exllamav2_params: Unpack[ExllamaV2Params],
    ) -> Union[str, List[str]]:
        """Generate completions for `prompts` in one blocking call.

        Returns the completions formatted by `reformat_output`.
        """
        exllamav2_params, prompts = self.prepare_generation_parameters(
            prompts,
            generation_parameters,
            sampling_parameters,
            structure_logits_processor,
            **exllamav2_params,
        )
        # `generator.generate` takes one `max_new_tokens` for the whole batch,
        # but `prepare_generation_parameters` computes one budget per prompt
        # (from the prompt's encoded length). The minimum is used so that no
        # prompt can exceed the max sequence length.
        output = self.generator.generate(
            prompt=prompts,
            gen_settings=exllamav2_params["gen_settings"],
            max_new_tokens=min(exllamav2_params["max_new_tokens"]),
            completion_only=True,
            encode_special_tokens=True,
            stop_conditions=exllamav2_params["stop_conditions"],
            add_bos=False,
            seed=exllamav2_params["seed"],
        )

        return self.reformat_output(output, sampling_parameters)

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        structure_logits_processor,
        sampling_parameters: SamplingParameters,
        **exllamav2_params: Unpack[ExllamaV2Params],
    ) -> Iterator[Union[str, List[str]]]:
        """Stream completions for `prompts`.

        Each prompt is enqueued as its own `ExLlamaV2DynamicJob`. The returned
        iterator yields one item per `generator.iterate()` step while jobs
        remain. Each item holds the text each prompt produced during that step
        ("" for prompts that produced none) and is formatted by
        `reformat_output`. Every yielded item is a fresh object.
        """
        from exllamav2.generator import ExLlamaV2DynamicJob

        exllamav2_params, prompts = self.prepare_generation_parameters(
            prompts,
            generation_parameters,
            sampling_parameters,
            structure_logits_processor,
            **exllamav2_params,
        )

        if isinstance(prompts, str):
            prompts = [prompts]
        batch_size = len(prompts)
        seed = exllamav2_params["seed"]
        # Maps each job's serial number (from `enqueue`) to its prompt index.
        order = {}
        for idx, prompt in enumerate(prompts):
            input_ids = self.generator.tokenizer.encode(
                prompt, encode_special_tokens=True, add_bos=False
            )

            job = ExLlamaV2DynamicJob(
                input_ids=input_ids,
                max_new_tokens=exllamav2_params["max_new_tokens"][idx],
                min_new_tokens=0,
                seed=seed,
                stop_conditions=exllamav2_params["stop_conditions"],
                gen_settings=exllamav2_params["gen_settings"],
                token_healing=False,
                decode_special_tokens=False,
            )

            # Each job gets a distinct seed so they do not all share one.
            if seed is not None:
                seed += 1

            order[self.generator.enqueue(job)] = idx

        def token_generator() -> Iterator[Union[str, List[str]]]:
            while self.generator.num_remaining_jobs():
                # Fresh buffer each step. A job with no streaming result in this
                # step yields "" instead of repeating its previous chunk, and
                # later steps never mutate an item that was already yielded.
                chunks = [""] * batch_size
                for result in self.generator.iterate():
                    if result["stage"] == "streaming":
                        # Append rather than overwrite, in case one job reports
                        # several results in a single step (presumably possible
                        # with a draft model). Text on the final (eos) result is
                        # kept as well.
                        chunks[order[result["serial"]]] += result.get("text", "")
                yield self.reformat_output(chunks, sampling_parameters)

        return token_generator()


def exl2(
    model_path: str,
    draft_model_path: Optional[str] = None,
    max_seq_len: Optional[int] = None,
    cache_q4: bool = False,
    paged: bool = True,
    max_chunk_size: Optional[int] = None,
) -> ExLlamaV2Model:
    """
    Load an ExLlamaV2 model.

    Parameters
    ----------
    model_path (str)
        Path to the model directory.
    draft_model_path (Optional[str], optional)
        Path to a draft model directory. When given, the draft model is loaded
        and passed to the generator with its own cache. Defaults to None.
    max_seq_len (Optional[int], optional)
        Maximum sequence length of the cache. Defaults to None, in which case
        `-1` is passed to the cache constructor (presumably "use the model
        config's length" — confirm against exllamav2).
    cache_q4 (bool, optional)
        Use a Q4 cache (`ExLlamaV2Cache_Q4`) for the model and any draft model.
        Defaults to False.
    paged (bool, optional)
        Passed to `ExLlamaV2DynamicGenerator` as `paged`. When False,
        `max_batch_size` is 1 instead of 4. Defaults to True.
    max_chunk_size (Optional[int], optional)
        Passed to the generator as `max_chunk_size`. When set, it also sets each
        config's `max_input_len` to this value and `max_attention_size` to its
        square. Defaults to None.

    Returns
    -------
    An `ExLlamaV2Model` instance.

    Raises
    ------
    `ImportError` if the `exllamav2` library is not installed.

    """
    try:
        from exllamav2 import (
            ExLlamaV2,
            ExLlamaV2Cache,
            ExLlamaV2Cache_Q4,
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
        )
        from exllamav2.generator import ExLlamaV2DynamicGenerator

    except ImportError as err:
        raise ImportError(
            "The `exllamav2`, `transformers` and `torch` libraries needs to be installed in order to use `exllamav2` models. "
            "Please run `pip install transformers torch git+https://github.com/lapp0/exllamav2@sampler-logits-processor` "
            "Documentation: https://dottxt-ai.github.io/outlines/latest/reference/models/exllamav2/"
        ) from err

    def _make_config(path: str):
        """Build a config for `path` with the chunk-size settings applied."""
        config = ExLlamaV2Config(path)
        if max_chunk_size is not None:
            config.max_input_len = max_chunk_size
            config.max_attention_size = max_chunk_size**2
        config.arch_compat_overrides()
        return config

    cache_cls = ExLlamaV2Cache_Q4 if cache_q4 else ExLlamaV2Cache

    config = _make_config(model_path)
    model = ExLlamaV2(config)
    if max_seq_len is None:
        max_seq_len = -1
    cache = cache_cls(model, max_seq_len=max_seq_len, lazy=True)
    model.load_autosplit(cache, progress=True)

    print("Loading tokenizer...")
    tokenizer = ExLlamaV2Tokenizer(config)
    max_batch_size = 4 if paged else 1

    draft_model = None
    draft_cache = None
    if draft_model_path is not None:
        # The draft config gets the same chunk-size settings and compat
        # overrides as the main config, since both models run in one generator.
        draft_model = ExLlamaV2(_make_config(draft_model_path))
        # Size the draft cache from the main cache's resolved length rather
        # than the raw argument (which may be -1), so the two caches match.
        # NOTE(review): assumes exllamav2 expects equal cache lengths.
        draft_cache = cache_cls(draft_model, max_seq_len=cache.max_seq_len, lazy=True)
        # Load the draft weights the same way as the main model's; without this
        # the generator would receive an unloaded draft model.
        draft_model.load_autosplit(draft_cache, progress=True)

    # Initialize the generator with all default parameters
    generator = ExLlamaV2DynamicGenerator(
        model=model,
        cache=cache,
        draft_model=draft_model,
        draft_cache=draft_cache,
        tokenizer=tokenizer,
        max_batch_size=max_batch_size,
        use_ngram_draft=False,
        max_chunk_size=max_chunk_size,
        paged=paged,
    )
    # Use the cache's resolved length (covers the `-1` / None case).
    max_seq_len = cache.max_seq_len

    outlines_tokenizer = OutlinesExLlamaV2Tokenizer(tokenizer)
    outlines_exl2_model = ExLlamaV2Model(generator, outlines_tokenizer, max_seq_len)
    return outlines_exl2_model
