# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-function-docstring,missing-class-docstring

from abc import ABC
from functools import lru_cache
from typing import Any, Type

import torch

from nemo.collections.common.tokenizers import AggregateTokenizer, TokenizerSpec

# Name of the reserved, slot-less role whose template is inserted at the beginning of the prompt.
PREAMBLE_ROLE = "preamble"

# Slots used to define when special tokens bos/eos should be inserted.
# These are special in the sense of how sentencepiece defines special tokens:
# They have to be specially inserted into the token sequence, and if they appear in the tokenized string,
# SPE wouldn't use the special token ids but rather tokenize them as if they were normal strings.
# We mimic SPE's behavior if these special slots are present in the template definition.
# To achieve that, insert |bos| / |eos| at the beginning/end of template.
# Only a |bos| at the very start and an |eos| at the very end of the text are recognized.
# E.g., inserting only bos in llama2 user role: "template": "|bos|[INST] |message| [/INST]"
BOS_SLOT = "|bos|"
EOS_SLOT = "|eos|"


class BaseModalityType:
    """
    Common base for modality descriptors that decide whether a slot value is acceptable.

    Concrete modalities are expected to override :meth:`matches`; the base version always raises.
    """

    @staticmethod
    def matches(value: Any) -> bool:
        """Tell whether ``value`` conforms to this modality (not implemented on the base)."""
        raise NotImplementedError

    def __repr__(self):
        # Rendered with a "Modality." prefix and the concrete class name.
        modality_name = type(self).__name__
        return f"Modality.{modality_name}()"


class Text(BaseModalityType):
    """Modality satisfied by any ``str`` value."""

    @staticmethod
    def matches(value: str) -> bool:
        """Return ``True`` only when ``value`` is a ``str`` instance."""
        is_text = isinstance(value, str)
        return is_text


class TextLiteral(BaseModalityType):
    """Modality for text values restricted to a fixed set of allowed strings."""

    def __init__(self, *items):
        # ``items`` is the tuple of permitted values, kept exactly as passed in.
        self.allowed_values = items

    def matches(self, value: str) -> bool:
        """Return ``True`` when ``value`` is a ``str`` that is one of ``allowed_values``."""
        if not isinstance(value, str):
            return False
        return value in self.allowed_values

    def __repr__(self):
        modality_name = type(self).__name__
        return f"Modality.{modality_name}(allowed_values={self.allowed_values})"


class Modality:
    """
    Modalities supported as PromptFormatter slot values.

    Serves as a namespace: ``Modality.Text`` and ``Modality.TextLiteral`` are aliases of the
    module-level classes with the same names.
    """

    # Accepts any ``str`` value.
    Text = Text
    # Accepts only ``str`` values from the set of items given when it is constructed.
    TextLiteral = TextLiteral


class PromptFormatter(ABC):
    """
    :class:`~nemo.collections.common.prompts.formatter.PromptFormatter` is intended to simplify
    working with various prompt format templates and encoding them into token ID tensors.

    It assumes a dialog-like structure, which is a list of turns, with each turn assigned to a role.
    Sub-classes of PromptFormatter define turn templates for each role under TEMPLATE class attribute.
    Each template may define some constant parts (e.g. begin-of-turn or end-of-turn tokens, whitespaces, etc.)
    and variable parts which we call "slots", that will be provided by the user during training or inference.

    A role is typically "user" and "assistant", and some popular models also use a "system" role.
    Other roles may be defined as well. We expect the role corresponding to the model's responses
    will be registered under class attribute called OUTPUT_ROLE.
    We reserve a special "preamble" role with no slots that will be inserted at the beginning of
    the formatted prompt, if "preamble" is present in TEMPLATE.

    A turn is a dict with keys "role" and "slots", where "slots" are a dict that maps slot names
    to values that should be filled in the template.
    For example, a user role template may be ``"Question: |message|"`` and corresponding ``slots`` would then be
    ``{"message": "What time is it?"}``.

    There is a special slot called ``|prompt_language|`` that's used to select the sub-tokenizer in
    :class:`~nemo.collections.common.tokenizers.aggregate_tokenizer.AggregateTokenizer`.
    It's only used when the tokenizer is aggregate; otherwise it's discarded.

    PromptFormatter supports constructing prompts for training (complete context and answers)
    and for inference (context-only).
    Training/inference is determined automatically; if the last role in a dialog is the OUTPUT_ROLE,
    that's an 'asked-and-answered' scenario, so we assume it's intended for training.
    We'll create a dict with tokenized results available under the following keys:

    * ``context_ids`` (all turns minus last one),
    * ``answer_ids`` (last turn)
    * ``input_ids`` (previous two values concatenated)
    * ``mask`` (boolean mask tensor of the same length as ``input_ids`` that's set to True on OUTPUT_ROLE turns)

    Typically, the user will use the ``encode_dialog`` method providing a list of turns to it.
    Example showing how to construct model inputs/outputs for training::

        >>> formatter = PromptFormatter(tokenizer)
        ... encoded_for_training = formatter.encode_dialog(
        ...     turns=[
        ...         {"role": "user", "slots": {"message": "What time is it?"}},
        ...         {"role": "assistant", "slots": {"message": "Ten o'clock."}},
        ...         {"role": "user", "slots": {"message": "PM or AM?"}},
        ...         {"role": "assistant", "slots": {"message": "AM, naturally! It's bright outside"}},
        ...     ]
        ... )

    Another example that shows how to use the same method to generate prompts for inference::


        >>> formatter = PromptFormatter(tokenizer)
        ... encoded_for_inference = formatter.encode_dialog(
        ...     turns=[
        ...         {"role": "user", "slots": {"message": "What time is it?"}},
        ...         {"role": "assistant", "slots": {"message": "Ten o'clock."}},
        ...         {"role": "user", "slots": {"message": "PM or AM?"}},
        ...     ]
        ... )

    """

    # Used to support AggregateTokenizer; this key selects the right sub-tokenizer for each turn.
    # Its value is read from a turn's slot values and passed to the tokenizer as the language.
    PROMPT_LANGUAGE_SLOT = "prompt_language"

    # Subclasses will be registered under this name, to be used via PromptFormatter.resolve(name).
    # Must be defined (not None) by every subclass; this is checked when the subclass is created.
    NAME = None

    # Template is a dict that maps:
    # * from a role name string (system/user/assistant/etc)
    # * to a dict with keys
    #   * "template" that has a string value (the prompt template)
    #   * "slots" (optional; treated as empty when absent) that has a value of dict[str, Modality]
    #       * keys of slots are the names of formattable slots in the prompt template
    #       * values of slots are :class:`Modality` objects that can be used to check
    #           whether a specific value conforms to a given modality requirements
    #           (e.g., Modality.Text may expect string objects).
    # Template is intended to be defined by the child classes.
    # Every declared slot must appear in its role's template as ``|slot|``, and the "preamble"
    # role (if present) must not declare slots; both are checked when the subclass is created.
    TEMPLATE = None

    # Turns under this role indicate responses by the model; if the last turn in
    # PromptFormatter.encode_dialog() ends with this role, it indicates a training example.
    # Must be defined (not None) by every subclass.
    OUTPUT_ROLE = None

    # When specified, we will append this prefix at the end of the prompt at inference time.
    # We detect inference time by the fact that the last turn is not from OUTPUT_ROLE.
    INFERENCE_PREFIX = None

    # When set to true, we will insert BOS/EOS symbol at the very beginning/end of the dialog
    # (i.e., not before/after every turn).
    # This is intended specifically for LLMs that use sentencepiece tokenizers with BOS/EOS
    # that don't normally exist in the tokenizer's vocab (i.e., no string input generates them
    # and you must insert them programmatically);
    # see: https://github.com/google/sentencepiece/issues/102#issuecomment-397150427
    # EOS is only inserted when the last turn comes from OUTPUT_ROLE (i.e., not at inference time).
    INSERT_BOS = False
    INSERT_EOS = False

    # Internal reserved field: registry mapping each subclass's NAME to the subclass itself.
    # It is filled in when a subclass is created and read by PromptFormatter.resolve().
    # The dict is defined once here, so all subclasses share the same registry object.
    _REGISTERED_FORMATTERS = {}

    def __init__(self, tokenizer: TokenizerSpec, defaults: list[dict] | None = None) -> None:
        self.tokenizer = tokenizer
        self._defaults = defaults if defaults is not None else []
        self._validate_defaults()

    def __init_subclass__(cls, **kwargs) -> None:
        ERR = "PromptFormatter subclass definition error:"
        if cls.__name__ not in cls._REGISTERED_FORMATTERS:
            for attr in ("NAME", "TEMPLATE", "OUTPUT_ROLE"):
                assert (
                    getattr(cls, attr, None) is not None
                ), f"{ERR} PromptFormatter subclass {cls} did not define a class attribute {attr}"
            assert cls.NAME not in cls._REGISTERED_FORMATTERS, (
                f"Cannot register {cls.__name__} under {cls.NAME}: another prompt formatter of type "
                f"{cls._REGISTERED_FORMATTERS[cls.NAME]} has already been registered under this name."
            )
            cls._REGISTERED_FORMATTERS[cls.NAME] = cls
        if "preamble" in cls.TEMPLATE:
            assert (
                len(cls.TEMPLATE["preamble"].get("slots", [])) == 0
            ), f"{ERR} Slots are not allowed for preamble template, but we found: '{cls.TEMPLATE['preamble']}'"
        for role in cls.get_roles():
            template = cls.get_template(role)
            for slot in cls.get_slots(role):
                assert (
                    _mangled(slot) in template
                ), f"{ERR} Slot '{slot}' not found in template '{template}' for role '{role}'"
        super().__init_subclass__(**kwargs)

    @classmethod
    def resolve(cls, name: str) -> Type["PromptFormatter"]:
        if name not in cls._REGISTERED_FORMATTERS:
            raise RuntimeError(
                f"Unknown prompt formatter: '{name}' (known formats: {', '.join(cls._REGISTERED_FORMATTERS.keys())})"
            )
        return cls._REGISTERED_FORMATTERS[name]

    @classmethod
    @lru_cache(1)
    def get_roles(cls) -> list[str]:
        return list(cls.TEMPLATE.keys())

    @classmethod
    def get_slots(cls, role: str) -> dict[str, Modality]:
        # returns a copy to avoid accidential mutation of a global object by the user
        return cls.TEMPLATE[role].get("slots", {}).copy()

    @classmethod
    def get_template(cls, role: str) -> str:
        return cls.TEMPLATE[role]["template"]

    def get_default_dialog_slots(self) -> list[dict]:
        """
        Returns a list of dialog turns that can be used as a skeleton to fill with actual slot values.
        If ``PromptFormatter`` was initialized with ``defaults`` argument, this method will return the
        defaults. Otherwise, every slot is pre-filled with ``None``.
        """

        def _get_default_for_role(role: str) -> dict:
            for turn in self._defaults:
                if turn["role"] == role:
                    return turn
            return {}

        return [
            {
                "role": role,
                "slots": {
                    slot: _get_default_for_role(role).get("slots", {}).get(slot) for slot in self.get_slots(role)
                },
            }
            for role in self.get_roles()
            if role != self.OUTPUT_ROLE
        ]

    def encode_turn(
        self, prompt_template: str, expected_slots: dict[str, Modality], slot_values: dict[str, Any]
    ) -> list[int]:
        prompt = prompt_template

        # normal case
        for slot in expected_slots:
            # For the final substitution of 'slot' in the template we have to mangle it to '|slot|' anyway,
            # but 'slot' form enables to use valid python identifiers as **kwargs
            # for passing slots around in user functions.
            value = slot_values.get(slot)
            assert value is not None, f"Missing required {slot=} in {slot_values=} for {prompt_template=}"
            prompt = prompt.replace(_mangled(slot), value)
        return self._apply_tokenizer(prompt, lang=slot_values.get(self.PROMPT_LANGUAGE_SLOT))

    def encode_dialog(self, turns: list[dict]) -> dict[str, torch.Tensor]:
        roles = self.get_roles()
        assert len(turns) > 0, "Empty dialog is not supported."
        for turn in turns:
            assert "role" in turn, f"A turn must have have a 'role' key. We received {turn=}"
            assert turn["role"] in roles, f"Found turn with {turn['role']=}, but available roles are {roles}"

        turn_tokens = []
        turn_token_counts = []
        turn_mask_values = []

        if self.INSERT_BOS:
            turn_tokens.append(self.tokenizer.bos)
            turn_token_counts.append(1)
            turn_mask_values.append(False)

        if "preamble" in self.TEMPLATE:
            preamble_turns = [idx for idx, t in enumerate(turns) if t["role"] == "preamble"]
            if not preamble_turns:
                turns = [{"role": "preamble", **self.TEMPLATE["preamble"]}] + turns
            else:
                assert (
                    len(preamble_turns) == 1 and preamble_turns[0] == 0
                ), f"Preamble can only be presented at turn 0 but we found preamble turns at indexes {preamble_turns}."

        is_inference = turns[-1]["role"] != self.OUTPUT_ROLE
        for turn in turns:
            role = turn["role"]
            expected_slots = self.get_slots(role)
            if "content" in turn and len(expected_slots) == 1:
                # User is leveraging the "standard" API prompting LLM; we'll map "content" value
                # to whatever is the name of the slot, when there's only one slot.
                slot_values = {k: turn["content"] for k in expected_slots.keys()}  # 1-item dict
            else:
                slot_values = turn.get("slots", {})
                if expected_slots:
                    assert slot_values, (
                        f"A turn for role {role} must have have a non-empty value under 'slots' key. "
                        f"We received {turn=}"
                    )
                    self._validate_slot_values(expected_slots, slot_values)
            template = self.get_template(role)
            tokens = self.encode_turn(template, expected_slots, slot_values)
            turn_tokens.extend(tokens)
            turn_token_counts.append(len(tokens))
            turn_mask_values.append(role == self.OUTPUT_ROLE)

        if is_inference and self.INFERENCE_PREFIX is not None:
            inference_prefix = self._apply_tokenizer(self.INFERENCE_PREFIX)
            turn_tokens.extend(inference_prefix)
            turn_token_counts.append(len(inference_prefix))
            turn_mask_values.append(False)  # not a training example

        # Insert EOS only when the last turn comes from the OUTPUT_ROLE.
        if self.INSERT_EOS and not is_inference:
            turn_tokens.append(self.tokenizer.eos)
            turn_token_counts[-1] += 1
            turn_mask_values.append(True)

        ans = {"input_ids": torch.tensor(turn_tokens, dtype=torch.long)}
        if turn_mask_values[-1]:
            # The last turn comes from OUTPUT_ROLE, i.e. it's a response from the system.
            # This indicates it's a training example for which we provide context/answer/mask.
            ans["context_ids"] = ans["input_ids"][: -turn_token_counts[-1]]
            ans["answer_ids"] = ans["input_ids"][-turn_token_counts[-1] :]
            ans["mask"] = torch.tensor(
                [
                    turn_mask_values[turn_idx]
                    for turn_idx, turn_len in enumerate(turn_token_counts)
                    for _ in range(turn_len)
                ],
                dtype=torch.bool,
            )
        else:
            ans["context_ids"] = ans["input_ids"]  # context == input for inference
        return ans

    def _apply_tokenizer(self, text: str, lang: str | None = None) -> list[int]:
        # Check if the tokenizer is aggregate and perform extra checks.
        is_agg = isinstance(self.tokenizer, AggregateTokenizer)
        if is_agg:
            assert lang is not None, (
                f"Missing key '{self.PROMPT_LANGUAGE_SLOT}' in slot_values -- cannot resolve "
                f"the correct sub-tokenizer in the aggregate tokenizer."
            )

        # Strip bos/eos if present and remember to apply them later.
        has_bos = text.startswith(BOS_SLOT)
        has_eos = text.endswith(EOS_SLOT)
        if has_bos:
            text = text[len(BOS_SLOT) :]
        if has_eos:
            text = text[: -len(EOS_SLOT)]

        # Tokenize, selecting the right API depending on aggregate/normal tokenizer.
        if is_agg:
            tokens = self.tokenizer.text_to_ids(text, lang)
        else:
            tokens = self.tokenizer.text_to_ids(text)

        # Lazily look up bos/eos and apply them. Lazy has the advantage that if a tokenizer
        # doesn't define bos/eos and the prompt format does not request them, everything just works.
        if has_eos:
            eos_id = self.tokenizer.get_eos(lang) if is_agg else self.tokenizer.eos
            tokens.append(eos_id)
        if has_bos:
            bos_id = self.tokenizer.get_bos(lang) if is_agg else self.tokenizer.bos
            tokens = [bos_id] + tokens

        return tokens

    def _validate_slot_values(self, expected: dict[str, Modality], received: dict[str, Any]) -> None:
        missing = set(expected) - set(received)
        assert not missing, f"The following slot values were not provided: {missing}"
        for slot in expected:
            expected_modality = expected[slot]
            value = received[slot]
            assert expected_modality.matches(
                value
            ), f"{slot=} received {value=} which does not match modality {expected_modality}"

    def _validate_defaults(self):
        if not self._defaults:
            return

        err = "Error in default prompt definition:"
        assert isinstance(self._defaults, list)
        for turn in self._defaults:
            assert isinstance(turn, dict)
            assert "role" in turn, f"{err} Missing required 'role' key. We received {turn=}"
            role = turn["role"]
            assert role in self.get_roles(), (
                f"{err} Invalid {role=} in {turn=} - " f"supported roles are: {self.get_roles()}."
            )
            if expected_slots := self.get_slots(role):
                assert "slots" in turn, (
                    f"{err} Missing required 'slots' key in {turn=} - "
                    f"we expected the following slots to be provided: {expected_slots}."
                )
                for slot in turn["slots"]:
                    assert slot in expected_slots, (
                        f"{err} Invalid {slot=} in {turn=}. "
                        f"The following slots are supported for {role=}: {expected_slots}"
                    )


def _mangled(slot: str) -> str:
    if not (slot[0] == "|" and slot[-1] == "|"):
        return f"|{slot}|"
    return slot


def _unmangled(slot: str) -> str:
    if slot[0] == "|" and slot[-1] == "|":
        return slot[1:-1]
    return slot
