# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Passes for debugging purposes."""

from __future__ import annotations

__all__ = ["CheckerPass"]  # Public API exported by this module.

from typing import Literal

import onnx  # noqa: TID251

import onnx_ir as ir
from onnx_ir.passes.common import _c_api_utils


class CheckerPass(ir.passes.PassBase):
    """Validate a model with ``onnx.checker.check_model``.

    The model is passed to the ONNX checker through
    ``_c_api_utils.call_onnx_api``. Nothing in this pass catches errors
    raised by the checker. The model is never edited, so the returned
    ``PassResult`` wraps the very same model object and reports it as
    unmodified.

    Args:
        full_check: Forwarded as ``check_model(full_check=...)``.
        skip_opset_compatibility_check: Forwarded as
            ``check_model(skip_opset_compatibility_check=...)``.
        check_custom_domain: Forwarded as
            ``check_model(check_custom_domain=...)``.
    """

    def __init__(
        self,
        full_check: bool = False,
        skip_opset_compatibility_check: bool = False,
        check_custom_domain: bool = False,
    ):
        super().__init__()
        # Store the checker switches; they are read each time ``call`` runs.
        self.full_check = full_check
        self.skip_opset_compatibility_check = skip_opset_compatibility_check
        self.check_custom_domain = check_custom_domain

    @property
    def in_place(self) -> Literal[True]:
        """Always True: the result carries the input model, not a copy."""
        return True

    @property
    def changes_input(self) -> Literal[False]:
        """Always False: checking leaves the input model untouched."""
        return False

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Check ``model`` with the ONNX checker and return it unmodified."""
        checker_options = {
            "full_check": self.full_check,
            "skip_opset_compatibility_check": self.skip_opset_compatibility_check,
            "check_custom_domain": self.check_custom_domain,
        }

        def _run_checker(proto: onnx.ModelProto) -> None:
            """Apply the configured checker options to a serialized model."""
            onnx.checker.check_model(proto, **checker_options)

        _c_api_utils.call_onnx_api(func=_run_checker, model=model)
        # Checking does not modify the model, so report it as unchanged.
        return ir.passes.PassResult(model, False)
