# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


use_pytriton = True
try:
    from pytriton.model_config import ModelConfig
    from pytriton.triton import Triton, TritonConfig
except Exception:
    use_pytriton = False

from nemo.deploy.deploy_base import DeployBase


class DeployPyTriton(DeployBase):
    """
    Deploys a model that implements the ITritonDeployable interface (nemo.deploy) to Triton Inference Server
    using PyTriton.

    Example:
        from nemo.deploy import DeployPyTriton, NemoQueryLLM
        from nemo.export.tensorrt_llm import TensorRTLLM

        trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files")
        trt_llm_exporter.export(
            nemo_checkpoint_path="/path/for/nemo/checkpoint",
            model_type="llama",
            tensor_parallelism_size=1,
        )

        nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="model_name", http_port=8000)
        nm.deploy()
        nm.run()
        nq = NemoQueryLLM(url="localhost", model_name="model_name")

        prompts = ["hello, testing GPT inference", "another GPT inference test?"]
        output = nq.query_llm(prompts=prompts, max_output_len=100)
        print("prompts: ", prompts)
        print("")
        print("output: ", output)
        print("")

        prompts = ["Give me some info about Paris", "Do you think London is a good city to visit?", "What do you think about Rome?"]
        output = nq.query_llm(prompts=prompts, max_output_len=250)
        print("prompts: ", prompts)
        print("")
        print("output: ", output)
        print("")

        nm.stop()
    """

    def __init__(
        self,
        triton_model_name: str,
        triton_model_version: int = 1,
        checkpoint_path: str = None,
        model=None,
        max_batch_size: int = 128,
        http_port: int = 8000,
        grpc_port: int = 8001,
        address: str = "0.0.0.0",
        allow_grpc: bool = True,
        allow_http: bool = True,
        streaming: bool = False,
        pytriton_log_verbose: int = 0,
    ):
        """
        A nemo checkpoint or model is expected for serving on Triton Inference Server.

        Args:
            triton_model_name (str): Name under which the model is served.
            triton_model_version (int): Version under which the model is served.
            checkpoint_path (str): Path of the nemo checkpoint (either this or ``model`` is expected).
            model (ITritonDeployable): A model implementing ``nemo.deploy.ITritonDeployable``.
            max_batch_size (int): Max batch size of the bound model (used in non-streaming mode only).
            http_port (int): HTTP port for the Triton server.
            grpc_port (int): gRPC port for the Triton server.
            address (str): Address the Triton HTTP and gRPC endpoints bind to.
            allow_grpc (bool): Whether to enable the gRPC endpoint.
            allow_http (bool): Whether to enable the HTTP endpoint.
            streaming (bool): If True, bind ``model.triton_infer_fn_streaming`` as a decoupled model
                instead of ``model.triton_infer_fn``.
            pytriton_log_verbose (int): Verbosity level forwarded to ``TritonConfig(log_verbose=...)``.
        """

        super().__init__(
            triton_model_name=triton_model_name,
            triton_model_version=triton_model_version,
            checkpoint_path=checkpoint_path,
            model=model,
            max_batch_size=max_batch_size,
            http_port=http_port,
            grpc_port=grpc_port,
            address=address,
            allow_grpc=allow_grpc,
            allow_http=allow_http,
            streaming=streaming,
            pytriton_log_verbose=pytriton_log_verbose,
        )

    def deploy(self):
        """
        Deploys the model to Triton Inference Server.

        Initializes the model, creates the Triton server configuration and binds the model's
        inference function to it. Call ``serve()`` or ``run()`` afterwards to start serving.

        Failures while configuring or binding are printed rather than raised, and leave
        ``self.triton`` as None so a subsequent ``serve()``/``run()``/``stop()`` raises.
        """

        self._init_nemo_model()

        if not use_pytriton:
            # Without this check the failure would surface as an opaque NameError on TritonConfig.
            self.triton = None
            print("pytriton is not available; install nvidia-pytriton to deploy the model.")
            return

        try:
            # Both modes honor the configured address, ports and log verbosity.
            triton_config = TritonConfig(
                log_verbose=self.pytriton_log_verbose,
                http_address=self.address,
                http_port=self.http_port,
                grpc_address=self.address,
                grpc_port=self.grpc_port,
                allow_grpc=self.allow_grpc,
                allow_http=self.allow_http,
            )

            if self.streaming:
                # TODO: can't set allow_http=True due to a bug in pytriton, will fix in latest pytriton
                # NOTE(review): allow_http is still forwarded as configured above; confirm whether it
                # should be forced to False for decoupled models, and whether max_batch_size should apply.
                infer_func = self.model.triton_infer_fn_streaming
                model_config = ModelConfig(decoupled=True)
            else:
                infer_func = self.model.triton_infer_fn
                model_config = ModelConfig(max_batch_size=self.max_batch_size)

            self.triton = Triton(config=triton_config)
            self.triton.bind(
                model_name=self.triton_model_name,
                model_version=self.triton_model_version,
                infer_func=infer_func,
                inputs=self.model.get_triton_input,
                outputs=self.model.get_triton_output,
                config=model_config,
            )
        except Exception as e:
            # Deliberate best-effort: report the failure and leave the instance undeployed.
            self.triton = None
            print(e)

    def _require_deployed(self):
        """Raises if there is no Triton instance from a successful ``deploy()`` call."""

        # getattr: within this class self.triton is only assigned by deploy()/serve(), so it may not exist yet.
        if getattr(self, "triton", None) is None:
            raise Exception("deploy should be called first.")

    def serve(self):
        """
        Starts serving the model and waits for requests (see ``run()`` for the asynchronous variant).

        Errors raised by the server are printed and leave ``self.triton`` as None.

        Raises:
            Exception: If ``deploy()`` has not been called or did not succeed.
        """

        self._require_deployed()

        try:
            self.triton.serve()
        except Exception as e:
            self.triton = None
            print(e)

    def run(self):
        """
        Starts serving the model asynchronously.

        Raises:
            Exception: If ``deploy()`` has not been called or did not succeed.
        """

        self._require_deployed()
        self.triton.run()

    def stop(self):
        """
        Stops serving the model.

        Raises:
            Exception: If ``deploy()`` has not been called or did not succeed.
        """

        self._require_deployed()
        self.triton.stop()
