# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-function-docstring

import torch
from lhotse.cut import Cut, MixedCut

from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample
from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn
from nemo.collections.common.prompts.formatter import BOS_SLOT, EOS_SLOT, Modality, PromptFormatter


class Llama2PromptFormatter(PromptFormatter):
    """
    This template has been validated to provide identical tokenized results to the official code
    in https://github.com/meta-llama/llama/blob/main/llama/generation.py
    """

    NAME = "llama2"
    OUTPUT_ROLE = "assistant"
    TEMPLATE = {
        "system_and_user": {
            "template": f"{BOS_SLOT}[INST] <<SYS>>\n|system|\n<</SYS>>\n\n|message| [/INST]",
            "slots": {
                "system": Modality.Text,
                "message": Modality.Text,
            },
        },
        "user": {
            "template": f"{BOS_SLOT}[INST] |message| [/INST]",
            "slots": {
                "message": Modality.Text,
            },
        },
        OUTPUT_ROLE: {
            "template": f"|message| {EOS_SLOT}",
            "slots": {
                "message": Modality.Text,
            },
        },
    }


@registered_prompt_format_fn(Cut, Llama2PromptFormatter)
def llama2(cut: Cut, prompt: Llama2PromptFormatter) -> dict[str, torch.Tensor]:
    if isinstance(cut, MixedCut):
        cut = cut.first_non_padding_cut
    if cut.has_custom("context"):
        context = cut.context
    elif cut.has_custom("question"):
        context = cut.question
    else:
        context = cut.default_context

    turns = []
    if cut.has_custom("system_prompt"):
        turns.append({"role": "system_and_user", "slots": {"system": cut.system_prompt, "message": context}})
    else:
        turns.append({"role": "user", "slots": {"message": context}})
    if (answer := cut.supervisions[0].text) is not None:
        turns.append({"role": "assistant", "slots": {"message": answer}})
    return prompt.encode_dialog(turns)


@registered_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter)
def llama2_src_tgt_text_example(example: SourceTargetTextExample, prompt: Llama2PromptFormatter):
    if example.question is not None:
        user_turn = {
            "role": "system_and_user",
            "slots": {"system": example.question.text, "message": example.source.text},
        }
    else:
        user_turn = {
            "role": "user",
            "slots": {"message": example.source.text},
        }
    return prompt.encode_dialog(
        [
            user_turn,
            {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.target.text}},
        ]
    )


@registered_prompt_format_fn(NeMoSFTExample, Llama2PromptFormatter)
def llama2_sft_text_example(example: NeMoSFTExample, prompt: Llama2PromptFormatter):
    first_user_turn = example.data["conversations"][0]["value"]
    if "system" in example.data and example.data["system"]:
        first_turn = {
            "role": "system_and_user",
            "slots": {"system": example.data["system"], "message": first_user_turn},
        }
    else:
        first_turn = {
            "role": "user",
            "slots": {"message": first_user_turn},
        }
    return prompt.encode_dialog(
        [first_turn]
        + [
            {"role": "user" if turn["from"] == "User" else prompt.OUTPUT_ROLE, "slots": {"message": turn["value"]}}
            for turn in example.data["conversations"][1:]
        ]
    )


LLAMA3_BOS = "<|begin_of_text|>"
LLAMA3_HEADER_BEGIN = "<|start_header_id|>"
LLAMA3_HEADER_END = "<|end_header_id|>"
LLAMA3_END_OF_TURN = "<|eot_id|>"
LLAMA3_NL = "\n\n"


class Llama3PromptFormatter(PromptFormatter):
    """
    Implemented following the code at:
     https://github.com/meta-llama/llama3/blob/main/llama/test_tokenizer.py#L56
    """

    NAME = "llama3"
    OUTPUT_ROLE = "assistant"
    INFERENCE_PREFIX = f"{LLAMA3_HEADER_BEGIN}assistant{LLAMA3_HEADER_END}{LLAMA3_NL}"
    TEMPLATE = {
        "preamble": {
            "template": LLAMA3_BOS,
        },
        "system": {
            "template": f"{LLAMA3_HEADER_BEGIN}system{LLAMA3_HEADER_END}{LLAMA3_NL}|message|{LLAMA3_END_OF_TURN}",
            "slots": {
                "message": Modality.Text,
            },
        },
        "user": {
            "template": f"{LLAMA3_HEADER_BEGIN}user{LLAMA3_HEADER_END}{LLAMA3_NL}|message|{LLAMA3_END_OF_TURN}",
            "slots": {
                "message": Modality.Text,
            },
        },
        OUTPUT_ROLE: {
            "template": f"{INFERENCE_PREFIX}|message|{LLAMA3_END_OF_TURN}",
            "slots": {
                "message": Modality.Text,
            },
        },
    }


@registered_prompt_format_fn(Cut, Llama3PromptFormatter)
def llama3(cut: Cut, prompt: Llama3PromptFormatter) -> dict[str, torch.Tensor]:
    if isinstance(cut, MixedCut):
        cut = cut.first_non_padding_cut
    if cut.has_custom("context"):
        context = cut.context
    elif cut.has_custom("question"):
        context = cut.question
    else:
        context = cut.default_context

    turns = []
    if cut.has_custom("system_prompt"):
        turns.append({"role": "system", "slots": {"message": cut.system_prompt}})
    turns.append({"role": "user", "slots": {"message": context}})
    if (answer := cut.supervisions[0].text) is not None:
        turns.append({"role": "assistant", "slots": {"message": answer}})
    return prompt.encode_dialog(turns)
