/*
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <cstring>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "onnx/common/common.h"
#include "onnx/common/constants.h"
#include "onnx/defs/data_type_utils.h"
#include "onnx/defs/shape_inference.h"

namespace ONNX_NAMESPACE {

// Context available when building a context-dependent function body:
// exposes the call-site node's attributes and its input/output presence
// and (input) types.
struct FunctionBodyBuildContext {
  // Returns the attribute with the given name, or nullptr if absent.
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  // True if the node supplies a non-empty actual value for this input.
  virtual bool hasInput(int inputIndex) const = 0;
  // True if the node supplies a non-empty actual value for this output.
  // NOTE(review): parameter is named inputIndex but indexes outputs.
  virtual bool hasOutput(int inputIndex) const = 0;
  // getInputType(i) should return null for missing optional inputs, or if
  // type-inference could not infer the input-type (erroneous model).
  virtual const TypeProto* getInputType(int inputIndex) const = 0;
  virtual ~FunctionBodyBuildContext() = default;
};

struct FunctionBodyBuildContextImpl : public FunctionBodyBuildContext {
  // Input_types: use a default TypeProto for missing types. We use a different convention
  // here (from FunctionBodyBuildContext) to simplify python interoperability.
  // The default value for input_types is included only for backward compatibility.
  // It can be used for functions that do not depend on the type-context, but
  // will not be sufficient for functions that do use the type-context.
  explicit FunctionBodyBuildContextImpl(const NodeProto& node_proto, const std::vector<TypeProto>& input_types = {})
      : node_proto_(node_proto), input_types_(input_types) {
    for (auto& attr : node_proto.attribute()) {
      attributesByName_[attr.name()] = &attr;
    }
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  bool hasInput(int inputIndex) const override {
    if (inputIndex >= node_proto_.input_size())
      return false;
    return !node_proto_.input(inputIndex).empty();
  }

  bool hasOutput(int inputIndex) const override {
    if (inputIndex >= node_proto_.output_size())
      return false;
    return !node_proto_.output(inputIndex).empty();
  }

  const TypeProto* getInputType(int inputIndex) const override {
    if (inputIndex < 0)
      return nullptr;
    size_t j = static_cast<size_t>(inputIndex);
    if (j >= input_types_.size())
      return nullptr;
    // Convert default value (no variant set) into null.
    if (input_types_[j].value_case() == TypeProto::ValueCase::VALUE_NOT_SET)
      return nullptr;
    return &input_types_[j];
  }

  std::unordered_map<std::string, const AttributeProto*> attributesByName_;

  NodeProto node_proto_;
  std::vector<TypeProto> input_types_;
};

// Predicate deciding whether a function body should be built for a given
// call-site context.
using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>;

class OpSchema;
// Builds a FunctionProto body for an op from the call-site context and the
// op's schema; returns true on success.
using ContextDependentFunctionBodyBuilder =
    std::function<bool(const FunctionBodyBuildContext&, const OpSchema&, FunctionProto&)>;

// Exception thrown when an operator schema is violated or malformed.
class SchemaError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  explicit SchemaError(const std::string& message) : std::runtime_error(message) {}

  // Returns the contextualized message once AppendContext has been called;
  // otherwise the original message passed at construction.
  ONNX_API const char* what() const noexcept override {
    return expanded_message_.empty() ? std::runtime_error::what() : expanded_message_.c_str();
  }

  // Attaches context to the error message. Each call rebuilds the message
  // from the original one, so a later call replaces earlier context.
  ONNX_API void AppendContext(const std::string& context) {
    expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
  }

 private:
  // Message with appended context; empty until AppendContext is called.
  std::string expanded_message_;
};

// Throws a SchemaError whose message is built by stringifying the arguments.
#define fail_schema(...) ONNX_THROW_EX(ONNX_NAMESPACE::SchemaError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

// Operator-set versions are plain integers.
using OperatorSetVersion = int;

// Set of data types allowed for a formal parameter or type constraint.
using DataTypeSet = std::unordered_set<DataType>;

// Type constraint map. Key is type string. Value is data type set and
// description.
using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>;

/**
 * @brief A class to record the schema of an op.
 *
 * OpSchema records the common interface of an op specified by its name.
 *
 * To register an OpSchema, one can use the macro ONNX_OPERATOR_SCHEMA(name) and
 * then append the various functions in the class. For example, for an op
 * that takes in two inputs and produces one output, one can write
 *
 *     ONNX_OPERATOR_SCHEMA(name)
 *         .NumInputs(2).NumOutputs(1);
 *
 * To manufacture methods that may be used to register an OpSchema
 * non-statically, the following may be used:
 *
 *     ONNX_OPERATOR_SET_SCHEMA(name, version, OpSchema()
 *         .NumInputs(2).NumOutputs(1));
 */
class OpSchema final {
 public:
  // Sentinel meaning "since-version not yet assigned".
  static constexpr int kUninitializedSinceVersion = -1;
  // Formal parameter options.
  enum FormalParameterOption : uint8_t {
    // The formal parameter is single and not optional.
    // Number of supplied actual parameters must be 1.
    Single = 0,
    // The formal parameter is single and optional.
    // Number of supplied actual parameters may be 0 or 1.
    Optional = 1,
    // The formal parameter is variadic.
    // Number of supplied actual parameters must be N or more, where
    // the minimum value N is indicated separately (default value 1).
    Variadic = 2,
  };
  enum DifferentiationCategory : uint8_t {
    // Whether this formal parameter is differentiable or not cannot
    // be statically determined. It also covers variadic formal
    // parameters which contain both of differentiable and
    // non-differentiable variables.
    Unknown = 0,
    // This formal parameter is differentiable. That is, this formal
    // parameter can be differentiable input of Gradient operator.
    Differentiable = 1,
    // This formal parameter is not differentiable. That is, this formal
    // parameter can not be differentiable input of Gradient operator.
    NonDifferentiable = 2
  };
  // Whether a node of this op always produces the same outputs for the
  // same inputs.
  enum class NodeDeterminism : uint8_t {
    // Determinism of the op is not specified.
    Unknown = 0,
    // The op may produce different outputs for identical inputs.
    NonDeterministic = 1,
    // The op always produces identical outputs for identical inputs.
    Deterministic = 2,
  };

  // Formal parameter representation, including input/output name, typeStr,
  // description, and type constraints.
  class FormalParameter final {
   public:
    // Constructor.
    FormalParameter() = default;

    // Constructor taking an explicit set of allowed data types.
    explicit FormalParameter(
        std::string name,
        DataTypeSet allowed_type_set,
        std::string type_str,
        std::string description,
        FormalParameterOption param_option = Single,
        bool is_homogeneous = true,
        int min_arity = 1,
        DifferentiationCategory differentiation_category = Unknown)
        : name_(std::move(name)),
          type_set_(std::move(allowed_type_set)),
          type_str_(std::move(type_str)),
#ifndef __ONNX_NO_DOC_STRINGS
          description_(std::move(description)),
#endif
          param_option_(param_option),
          is_homogeneous_(is_homogeneous),
          min_arity_(min_arity),
          differentiation_category_(differentiation_category) {
#ifdef __ONNX_NO_DOC_STRINGS
      ONNX_UNUSED_PARAMETER(description);
#endif
    }

    // Constructor without an explicit type set; type_set_ is left empty
    // (to be populated later from the type string / type constraints).
    explicit FormalParameter(
        std::string name,
        std::string description,
        std::string type_str,
        FormalParameterOption param_option = Single,
        bool is_homogeneous = true,
        int min_arity = 1,
        DifferentiationCategory differentiation_category = Unknown)
        : name_(std::move(name)),
          type_str_(std::move(type_str)),
#ifndef __ONNX_NO_DOC_STRINGS
          description_(std::move(description)),
#endif
          param_option_(param_option),
          is_homogeneous_(is_homogeneous),
          min_arity_(min_arity),
          differentiation_category_(differentiation_category) {
#ifdef __ONNX_NO_DOC_STRINGS
      ONNX_UNUSED_PARAMETER(description);
#endif
    }

    // Get formal parameter name.
    ONNX_API const std::string& GetName() const;

    // Get allowed data types.
    ONNX_API const DataTypeSet& GetTypes() const;

    // Get formal parameter type string.
    ONNX_API const std::string& GetTypeStr() const;

    // Get formal parameter description.
    ONNX_API const std::string& GetDescription() const;

    // Get the parameter option, it could be Single, Optional or Variadic.
    ONNX_API FormalParameterOption GetOption() const;

    // Get whether a variadic parameter requires all to be of same type
    ONNX_API bool GetIsHomogeneous() const;

    // Get minimum arity. Applicable only in the Variadic case.
    ONNX_API int GetMinArity() const;

    // Get the differentiation property of this formal parameter.
    ONNX_API DifferentiationCategory GetDifferentiationCategory() const;

   private:
    friend class OpSchema;

    // Mutable access to the allowed type set (used by OpSchema).
    DataTypeSet& MutableTypes();

    // Formal parameter name.
    std::string name_;

    // A set of data types supported for <*this> formal parameter.
    // It should contain at least one element if this formal parameter is good.
    DataTypeSet type_set_;

    // The <parameter type> string specified when registering an op.
    // It could be a supported data type or a type constraint key, which
    // maps to a set of supported data types.
    std::string type_str_;

    // Formal parameter description.
    std::string description_;

    // Formal parameter option.
    FormalParameterOption param_option_{};

    // For variadic parameters, a flag indicating if all parameters must be of
    // same type
    bool is_homogeneous_{};

    // Minimum number of parameters expected. Applicable only for Variadic.
    int min_arity_{};

    // True if this parameter can be an differentiable inputs of Gradient.
    // Otherwise, using this parameter as an differentiable inputs of Gradient
    // is prohibited.
    DifferentiationCategory differentiation_category_{};
  };

  enum class SupportType : uint8_t {
    COMMON, // Supported by all frameworks that support this IR.
    EXPERIMENTAL, // This OP is experimental and can be changed or removed in
                  // the future.
  };

  // Default constructor: placeholder schema with unknown name/file.
  OpSchema() : OpSchema("unknown", "unknown", 0) {}
  // Constructs a schema recording the source location it is registered from.
  OpSchema(std::string name, std::string file, int line)
      : name_(std::move(name)), file_(std::move(file)), line_(line), support_(SupportType::COMMON) {}

  /**
   * @brief Returns the file that the op schema is registered from.
   */
  ONNX_API const std::string& file() const {
    return file_;
  }

  /**
   * @brief Returns the line in file that the op schema is registered from.
   */
  ONNX_API int line() const {
    return line_;
  }

  /**
   * @brief Returns the support level of the op schema.
   */
  ONNX_API SupportType support_level() const {
    return support_;
  }

  /**
   * @brief Returns the docstring of the op schema.
   * @return nullptr when no docstring was set (or doc strings are disabled).
   */
  ONNX_API const char* doc() const {
    return doc_.empty() ? nullptr : doc_.c_str();
  }

  // Check if input and output types fall into valid set and match each other
  ONNX_API void CheckInputOutputType(struct InferenceContext&) const;

  /**
   * @brief Verifies if a NodeProto matches the pattern specified in
   * the schema.
   */
  void Verify(const NodeProto& node) const;

  // Functions to set the property of the operator schemas.
  // Sets the number of inputs, either a fixed number or a min and a max.

  /**
   * The earliest operator set version which this operator was
   * present in.  If an operator has had no BC-breaking changes,
   * this is simply the first operator set the operator was a member
   * of; if it has had BC-breaking changes, then for the semantics
   * /as described/ in the OpSchema entry, this version describes
   * the operator set which introduced the BC-breaking change.
   *
   * For example, suppose op Foo was added in v3, and had a BC-breaking
   * change in v6.  Then there will be an op schema entry for Foo with
   * SinceVersion(3), and another, updated op schema entry for Foo
   * with SinceVersion(6).
   */
  ONNX_API OpSchema& SinceVersion(OperatorSetVersion n); // aka int

  /**
   * Marks this op as deprecated as of it's since_version. This will cause the
   * Schema() lookup functions to return nullptr when the version is in the
   * deprecated range.
   */
  ONNX_API OpSchema& Deprecate();

  ONNX_API bool Deprecated() const {
    return deprecated_;
  }

  /**
   * @brief Input could be one of the values specified in allowed_input_nums.
   */
  ONNX_API OpSchema& NumInputs(std::unordered_set<int> allowed_input_nums);

  /**
   * @brief Output could be one of the values specified in allowed_output_nums.
   */
  ONNX_API OpSchema& NumOutputs(std::unordered_set<int> allowed_output_nums);

  // Shape Inference
  //
  // Note that signatures are defined to allow for forward-declaring
  // any structs used from ir.h
  ONNX_API OpSchema& TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
  // Returns the registered inference function, or a no-op dummy when none
  // was registered.
  InferenceFunction GetTypeAndShapeInferenceFunction() const {
    return tensor_inference_function_ ? tensor_inference_function_ : dummyInferenceFunction;
  }

  ONNX_API OpSchema& PartialDataPropagationFunction(DataPropagationFunction dataPropagationFunction);
  // Returns the registered data-propagation function, or a no-op dummy when
  // none was registered.
  ONNX_API DataPropagationFunction GetDataPropagationFunction() const {
    return data_propagation_function_ ? data_propagation_function_ : dummyDataPropagationFunction;
  }

  // Set the support level for the op schema.
  ONNX_API OpSchema& SetSupportLevel(SupportType supportType);

  // Functions to do documentation for the operator schema.
  // This may be disabled to save memory.
  ONNX_API OpSchema& SetDoc(const char* doc) {
#ifndef __ONNX_NO_DOC_STRINGS
    SetDoc(std::string(doc));
#else
    ONNX_UNUSED_PARAMETER(doc);
#endif

    return *this;
  }

  ONNX_API OpSchema& SetDoc(const std::string& doc) {
#ifndef __ONNX_NO_DOC_STRINGS
    doc_ = doc;
#else
    ONNX_UNUSED_PARAMETER(doc);
#endif
    return *this;
  }

  // Functions to specify name for the operator schema.
  ONNX_API OpSchema& SetName(const char* name);
  ONNX_API OpSchema& SetName(std::string name);

  // Functions to specify code location for the operator schema.
  ONNX_API OpSchema& SetLocation(const char* file, int line);
  ONNX_API OpSchema& SetLocation(std::string file, int line);

  // Functions to specify domain for the operator schema.
  // Default domain value (ONNX_DOMAIN) means it's ONNX domain.
  ONNX_API OpSchema& SetDomain(const char* domain);
  ONNX_API OpSchema& SetDomain(std::string domain);

  // Declaration of a single attribute of the operator.
  struct Attribute final {
    // Attribute without a default value; may be required or optional.
    Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)
        : name(std::move(name_)), description(std::move(description_)), type(type_), required(required_) {}

    // Optional attribute with a default value; type is taken from the default.
    Attribute(std::string name_, std::string description_, AttributeProto default_value_)
        : name(std::move(name_)),
          description(std::move(description_)),
          type(default_value_.type()),
          required(false),
          default_value(std::move(default_value_)) {}

    const std::string name;
    const std::string description;
    AttributeProto::AttributeType type;
    bool required;
    // Meaningful only when required is false.
    AttributeProto default_value;
  };

  // Register a fully-constructed attribute declaration.
  ONNX_API OpSchema& Attr(Attribute attr);

// Register "optional" attribute with default value.
#define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName)                                                                    \
  OpSchema& Attr(                                                                                                   \
      std::string name, std::string description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
  /* non-STL wrapper to reduce binary size */                                                                       \
  OpSchema& Attr(                                                                                                   \
      const char* name, const char* description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
  OpSchema& Attr(                                                                                                   \
      std::string name,                                                                                             \
      std::string description,                                                                                      \
      AttributeProto::AttributeType type,                                                                           \
      const std::vector<TypeName>& defaultValue);

  ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t)
  ATTR_SETTER_WITH_DEFAULT_VALUE(float)
  ATTR_SETTER_WITH_DEFAULT_VALUE(std::string)
  ATTR_SETTER_WITH_DEFAULT_VALUE(TensorProto)
  ATTR_SETTER_WITH_DEFAULT_VALUE(GraphProto)
  ATTR_SETTER_WITH_DEFAULT_VALUE(TypeProto)

  // Register an attribute along with an extra string documenting the
  // condition under which the attribute applies.
  ONNX_API OpSchema& Attr(
      std::string name,
      std::string description,
      std::string conditionExplanation,
      AttributeProto::AttributeType attr_type);

  // Register "required" attribute without default value.
  ONNX_API OpSchema&
  Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required = true);

  // Non-STL wrapper to reduce binary size
  ONNX_API OpSchema&
  Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required = true);

  ONNX_API OpSchema& AllowUncheckedAttributes();

  // Type constraint.
  struct TypeConstraintParam final {
    TypeConstraintParam(
        std::string type_param_str_,
        std::vector<std::string> allowed_type_strs_,
        std::string description_)
        : type_param_str(std::move(type_param_str_)),
          allowed_type_strs(std::move(allowed_type_strs_)),
          description(std::move(description_)) {}

    // Type parameter string, for example, "T", "T1", etc.
    std::string type_param_str;
    // Allowed type strings for <*this> type parameter, for example,
    // "tensor(float)".
    std::vector<std::string> allowed_type_strs;
    // Type parameter description.
    std::string description;
  };

  // Grammar for type strings used in Input(), Output().
  // <type> ::= <data_type> |
  //            tensor(<data_type>) |
  //            seq(<type>) |
  //            map(<data_type>, <type>) |
  //            <type_parameter>
  // <data_type> :: = float | int32 | string | bool | uint8
  //                | int8 | uint16 | int16 | int64 | float16 | double
  // <type_parameter> ::= any type parameter string, say "T".
  //
  // NOTE: 1) <type_parameter> will always be together with a type constraints
  // specification.
  //       2) <type> ::= <data_type> means the data is scalar (zero dimension).
  //
  // Example:
  // ONNX_OPERATOR_SET_SCHEMA(Sum, 1, OpSchema()
  // .Input(0, "input_a", "the first input", "T")
  // .Input(1, "input_b", "the second input", "T")
  // .Output(0, "sum", "the sum of two numbers", "T")
  // .TypeConstraint("T", {"float", "double", "int32"}, "allowed data types for
  // sum."))
  //
  // Optional = true means that the input might have empty input value
  // (represented as "") in the graph even though the later inputs have values.
  // It's useful for complex situation when there are several independent
  // optional inputs.
  ONNX_API OpSchema& Input(int n, FormalParameter formal_parameter);

  ONNX_API OpSchema& Input(
      int n,
      std::string name,
      const std::string& description,
      std::string type_str,
      FormalParameterOption param_option = Single,
      bool is_homogeneous = true,
      int min_arity = 1,
      DifferentiationCategory differentiation_category = Unknown);

  // Non-STL wrapper to reduce binary size
  ONNX_API OpSchema& Input(
      int n,
      const char* name,
      const char* description,
      const char* type_str,
      FormalParameterOption param_option = Single,
      bool is_homogeneous = true,
      int min_arity = 1,
      DifferentiationCategory differentiation_category = Unknown);

  ONNX_API OpSchema& Output(int n, FormalParameter formal_parameter);

  ONNX_API OpSchema& Output(
      int n,
      std::string name,
      const std::string& description,
      std::string type_str,
      FormalParameterOption param_option = Single,
      bool is_homogeneous = true,
      int min_arity = 1,
      DifferentiationCategory differentiation_category = Unknown);

  // Non-STL wrapper to reduce binary size
  ONNX_API OpSchema& Output(
      int n,
      const char* name,
      const char* description,
      const char* type_str,
      FormalParameterOption param_option = Single,
      bool is_homogeneous = true,
      int min_arity = 1,
      DifferentiationCategory differentiation_category = Unknown);

  ONNX_API OpSchema&
  TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description);

  // Non-STL wrapper to reduce binary size
  ONNX_API OpSchema&
  TypeConstraint(const char* type_str, std::initializer_list<const char*> constraints, const char* description);

  // Convenience members for types

  // All high-precision numeric types.
  ONNX_API static const std::vector<std::string>& numeric_types_for_math_reduction_ir10() {
    return numeric_types_for_math_reduction_ir9();
  }

  ONNX_API static const std::vector<std::string>& numeric_types_for_math_reduction_ir9();

  ONNX_API static const std::vector<std::string>& numeric_types_for_math_reduction_ir4();

  ONNX_API static const std::vector<std::string>& numeric_types_for_math_reduction();

  ONNX_API static const std::vector<std::string>& all_numeric_types_ir13();

  ONNX_API static const std::vector<std::string>& all_numeric_types_ir12();

  ONNX_API static const std::vector<std::string>& all_numeric_types_ir11();

  ONNX_API static const std::vector<std::string>& all_numeric_types_ir10();

  ONNX_API static const std::vector<std::string>& all_numeric_types_ir9();

  ONNX_API static const std::vector<std::string>& all_numeric_types_ir4();

  ONNX_API static const std::vector<std::string>& all_numeric_types();

  ONNX_API static const std::vector<std::string>& all_numeric_sequence_types();

  ONNX_API static const std::vector<std::string>& all_tensor_types();

  ONNX_API static const std::vector<std::string>& all_tensor_types_ir4();

  ONNX_API static const std::vector<std::string>& all_non_complex_numeric_types_plus_bool_ir4();

  ONNX_API static const std::vector<std::string>& all_float_types_ir4();

  ONNX_API static const std::vector<std::string>& all_float_types_plus_Xint8_ir4();

  ONNX_API static const std::vector<std::string>& all_float_types_ir9();

  // IR10 float types are identical to IR9's.
  ONNX_API static const std::vector<std::string>& all_float_types_ir10() {
    return all_float_types_ir9();
  }

  ONNX_API static const std::vector<std::string>& all_tensor_types_ir9();

  ONNX_API static const std::vector<std::string>& all_tensor_types_ir10();

  ONNX_API static const std::vector<std::string>& all_non_complex_tensor_types_ir10();

  ONNX_API static const std::vector<std::string>& all_tensor_types_ir11();

  ONNX_API static const std::vector<std::string>& all_non_complex_tensor_types_ir11();

  ONNX_API static const std::vector<std::string>& all_tensor_types_ir12();

  ONNX_API static const std::vector<std::string>& all_non_complex_tensor_types_ir12();

  ONNX_API static const std::vector<std::string>& all_tensor_types_ir13();

  ONNX_API static const std::vector<std::string>& all_non_complex_tensor_types_ir13();

  ONNX_API static const std::vector<std::string>& all_tensor_sequence_types();

  ONNX_API static const std::vector<std::string>& all_tensor_sequence_types_ir4();

  ONNX_API static const std::vector<std::string>& all_tensor_sequence_types_ir9();

  ONNX_API static const std::vector<std::string>& all_tensor_sequence_types_ir10();

  ONNX_API static const std::vector<std::string>& all_tensor_sequence_types_ir11();

  ONNX_API static const std::vector<std::string>& all_tensor_sequence_types_ir12();

  ONNX_API static const std::vector<std::string>& all_tensor_sequence_types_ir13();

  ONNX_API static const std::vector<std::string>& all_optional_types();

  ONNX_API static const std::vector<std::string>& all_optional_types_ir4();

  ONNX_API static const std::vector<std::string>& all_optional_types_ir9();

  ONNX_API static const std::vector<std::string>& all_optional_types_ir10();

  ONNX_API static const std::vector<std::string>& all_optional_types_ir11();

  ONNX_API static const std::vector<std::string>& all_optional_types_ir12();

  ONNX_API static const std::vector<std::string>& all_optional_types_ir13();

  // Calls the passed function with `this` as an argument. Useful for
  // adding docs for templated/macro ops.
  ONNX_API OpSchema& FillUsing(const std::function<void(OpSchema&)>& populator);

  friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema);

  ONNX_API const std::string& domain() const {
    return domain_;
  }

  ONNX_API const std::unordered_map<std::string, Attribute>& attributes() const {
    return attributes_;
  }

  // Get input formal parameters.
  ONNX_API const std::vector<FormalParameter>& inputs() const {
    return inputs_;
  }

  // Get output formal parameters.
  ONNX_API const std::vector<FormalParameter>& outputs() const {
    return outputs_;
  }

  ONNX_API const std::vector<TypeConstraintParam>& typeConstraintParams() const {
    return type_constraint_params_;
  }

  ONNX_API const TypeConstraintMap& typeConstraintMap() const {
    return type_constraints_;
  }

  ONNX_API const std::string& Name() const {
    return name_;
  }

  ONNX_API OperatorSetVersion SinceVersion() const {
    return since_version_;
  }

  ONNX_API int since_version() const {
    return since_version_;
  }

  ONNX_API bool deprecated() const {
    return deprecated_;
  }

  ONNX_API int min_input() const {
    return min_input_;
  }

  ONNX_API int max_input() const {
    return max_input_;
  }

  ONNX_API int min_output() const {
    return min_output_;
  }

  ONNX_API int max_output() const {
    return max_output_;
  }

  ONNX_API bool has_type_and_shape_inference_function() const {
    return static_cast<bool>(tensor_inference_function_);
  }

  ONNX_API bool has_data_propagation_function() const {
    return static_cast<bool>(data_propagation_function_);
  }

  // Returns the opset versions for which a (static) function body is
  // registered, in ascending order.
  ONNX_API std::vector<int> function_opset_versions() const {
    std::vector<int> opset_versions;
    opset_versions.reserve(opset_version_to_function_body_.size());
    for (const auto& pair : opset_version_to_function_body_) {
      opset_versions.push_back(pair.first);
    }
    return opset_versions;
  }

  // True if at least one (static) function body is registered.
  ONNX_API bool HasFunction() const {
    return !opset_version_to_function_body_.empty();
  }

  // Register a function body, built from the given nodes, for the given opset
  // version.
  ONNX_API OpSchema& FunctionBody(
      const std::vector<NodeProto>& func_nodes,
      int opset_version = kUninitializedSinceVersion);

  // Register a function body with explicit opset imports for the ops it uses.
  ONNX_API OpSchema& FunctionBody(
      const std::vector<NodeProto>& func_nodes,
      const std::vector<OperatorSetIdProto>& opsets,
      int opset_version = kUninitializedSinceVersion);

  // Register a function body specified as text.
  ONNX_API OpSchema& FunctionBody(const char* func_body, int opset_version = kUninitializedSinceVersion);

  // since_version_ of an OpSchema tells the last opset version when an op is defined.
  // When the op's definition is changed, a new OpSchema (of the same op_type) is created
  // with a newer since_version_, reflecting the opset version at the time of change.
  // For a function op, operators used to define its function body may change
  // while there is no change to the function op definition itself.
  // When this happens, multiple function bodies are provided, each for a specific opset version.
  //
  // Take LogSoftmax for example. Its latest opset version is 13.
  // In LogSoftmax's function body, ReduceMax (with since_version_ 1, 11, 12, 18) is used.
  // When a model containing LogSoftmax with opset_import version within 13 to 17 is loaded, function body
  // with opset_version 13 is used for inlining.
  // When the same model but opset_import version 18 is loaded, function body
  // with opset_version 18 is used for inlining.
  // Clearly function body for opset_import version 13 will not work
  // in a model with opset_import version 18 because the function body make wrong use of ReduceMax(18).
  // Inside GetFunction we ensure that ops being used to construct a function body do not endure such
  // issue.
  ONNX_API const FunctionProto* GetFunction(
      int requested_opset_version = OpSchema::kUninitializedSinceVersion,
      bool validate = false) const;

  // Returns the opset versions for which a context-dependent function builder
  // is registered, in ascending order.
  ONNX_API std::vector<int> context_dependent_function_opset_versions() const {
    std::vector<int> opset_versions;
    opset_versions.reserve(opset_version_to_function_builder_.size());
    for (const auto& pair : opset_version_to_function_builder_) {
      opset_versions.push_back(pair.first);
    }
    return opset_versions;
  }

  // True if at least one context-dependent function builder is registered.
  ONNX_API bool HasContextDependentFunction() const {
    return !opset_version_to_function_builder_.empty();
  }

  // True if a context-dependent function builder is registered for exactly
  // this opset version.
  ONNX_API bool HasContextDependentFunctionWithOpsetVersion(int opset_version) const {
    return opset_version_to_function_builder_.find(opset_version) != opset_version_to_function_builder_.end();
  }

  ONNX_API OpSchema& SetContextDependentFunctionBodyBuilder(
      ContextDependentFunctionBodyBuilder,
      int opset_version = kUninitializedSinceVersion);

  // Builds the context-dependent function body into function_proto; returns
  // whether a builder ran successfully.
  ONNX_API bool BuildContextDependentFunction(
      const FunctionBodyBuildContext& ctx,
      FunctionProto& function_proto,
      int requested_opset_version = OpSchema::kUninitializedSinceVersion) const;

  // Verifies that the schema is valid and all specifications are compatible.
  // It will also parse all type strings specified for inputs/outputs into valid
  // TypeProto and create global unique string pointer as the DataType for
  // efficiency.
  ONNX_API void Finalize();

  // Build function with information stored in opschema
  ONNX_API void BuildFunction(FunctionProto& function_body) const;

  ONNX_API NodeDeterminism GetNodeDeterminism() const;
  ONNX_API OpSchema& SetNodeDeterminism(NodeDeterminism node_determinism);

 private:
  // Parses the type strings of the given formal parameters and fills in
  // their allowed data-type sets.
  void ParseAndSetTypes(
      /*out*/ std::vector<OpSchema::FormalParameter>* formalParameters);
  // Checks that every op referenced by the function body is usable at the
  // requested opset version.
  bool ValidateReferencedOpsInFunction(
      const FunctionProto* function,
      int requested_opset_version,
      int function_since_version,
      std::unordered_set<std::string>* updated_ops = nullptr) const;
  // Rewrites the function's opset imports to match the given opset version.
  void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const;

  /**
   * @brief A common function to generate a prefix string for use in fail_check during the verify function.
   * @param  node_name If empty, the returned string will not include the node name.
   * @return std::string The prefix string.
   */
  std::string VerifyFailPrefix(std::string_view node_name) const;

  /**
   * @brief Verifies if the input number matches the pattern specified in the schema.
   * @param input_num The number of inputs to be verified against the schema.
   * @param node_name The prefix string used if the check fails.
   */
  void VerifyInputNum(int input_num, std::string_view node_name = "") const;

  /**
   * @brief Verifies if the output number matches the pattern specified in the schema.
   * @param output_num The number of outputs to be verified against the schema.
   * @param node_name The prefix string used if the check fails.
   */
  void VerifyOutputNum(int output_num, std::string_view node_name = "") const;

  // Operator name.
  std::string name_;
  // Source file the schema was registered from.
  std::string file_;
  // Docstring (may be empty when doc strings are disabled).
  std::string doc_;
  // Default domain value ("") means it's ONNX domain.
  std::string domain_ = ONNX_DOMAIN;
  // Declared attributes, keyed by attribute name.
  std::unordered_map<std::string, Attribute> attributes_;
  bool allows_unchecked_attributes_ = false;
  // Formal input/output parameters, in positional order.
  std::vector<FormalParameter> inputs_;
  std::vector<FormalParameter> outputs_;
  std::vector<TypeConstraintParam> type_constraint_params_;
  TypeConstraintMap type_constraints_;
  // Source line the schema was registered from.
  int line_ = 0;
  SupportType support_;
  int min_input_ = 0;
  int max_input_ = 0;
  int min_output_ = 0;
  int max_output_ = 0;
  // The default is a little goofy, since it is never what you want
  OperatorSetVersion since_version_ = kUninitializedSinceVersion;
  bool deprecated_{};
  // Predicates over the actual input/output counts; default accepts any count.
  std::function<bool(int)> num_inputs_allowed_ = [](int) { return true; };
  std::function<bool(int)> num_outputs_allowed_ = [](int) { return true; };
  InferenceFunction tensor_inference_function_;
  DataPropagationFunction data_propagation_function_;

  // Static function bodies / context-dependent builders, keyed by opset version.
  std::map<int, std::shared_ptr<FunctionProto>> opset_version_to_function_body_;
  std::map<int, ContextDependentFunctionBodyBuilder> opset_version_to_function_builder_;

  NodeDeterminism node_determinism_ = NodeDeterminism::Unknown;
};

// Map type to store operator schemas. The format is,
// <OpName, <Domain, <OperatorSetVersion, OpSchema>>>.
// The innermost container is an ordered std::map so that "latest version
// <= N" queries can use lower_bound/rbegin (see OpSchemaRegistry::Schema).
using OpName_Domain_Version_Schema_Map =
    std::unordered_map<std::string, std::unordered_map<std::string, std::map<OperatorSetVersion, OpSchema>>>;

/**
 * @brief Abstract interface for schema lookup, allowing alternative registry
 * implementations to be substituted where schemas are resolved.
 */
class ISchemaRegistry {
 public:
  virtual ~ISchemaRegistry() = default;

  // Returns the schema registered for `key` in `domain` with the highest
  // version not exceeding maxInclusiveVersion, or nullptr if none exists.
  ONNX_API virtual const OpSchema*
  // NOLINTNEXTLINE(google-default-arguments)
  GetSchema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) const = 0;
};

/**
 * @brief A registry to hold all the operator schemas.
 */
class OpSchemaRegistry final : public ISchemaRegistry {
 public:
  // A singleton class to store domain to min/max op_set version map, as well as
  // domain to last-release op_set version map.
  class DomainToVersionRange final {
   public:
    DomainToVersionRange() {
      // Increase the highest version when you make BC-breaking changes to the
      // operator schema on specific domain. Update the lowest version when it's
      // determined to remove too old version history.
      map_[ONNX_DOMAIN] = std::make_pair(1, 25);
      map_[AI_ONNX_ML_DOMAIN] = std::make_pair(1, 5);
      map_[AI_ONNX_TRAINING_DOMAIN] = std::make_pair(1, 1);
      // ONNX's preview domain contains operators subject to change, so
      // versioning is not meaningful and that domain should have only one
      // version.
      map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 1);
      // Version corresponding last release of ONNX. Update this to match with
      // the max version above in a *release* version of ONNX. But in other
      // versions, the max version may be ahead of the last-release-version.
      last_release_version_map_[ONNX_DOMAIN] = 25;
      last_release_version_map_[AI_ONNX_ML_DOMAIN] = 5;
      last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1;
      last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1;
    }

    ONNX_API const std::unordered_map<std::string, std::pair<int, int>>& Map() const {
      return map_;
    }

    ONNX_API const std::unordered_map<std::string, int>& LastReleaseVersionMap() const {
      return last_release_version_map_;
    }

    // Add customized domain to min/max version.
    // Onnx partners are able to use onnx operator schema api to
    // register customized op in their own domain.
    // Can optionally specify last_release_version (to make it similar to
    // standard ONNX domains as above). Custom-domains are free to interpret
    // this as appropriate (that is, as relative to releases of custom-domain
    // as opposed to ONNX releases).
    // Guarded by mutex_ so concurrent registrations are safe.
    ONNX_API void
    AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) {
      std::lock_guard<std::mutex> lock(mutex_);
      if (map_.count(domain) != 0) {
        std::stringstream err;
        err << "Trying to add a domain to DomainToVersion map, but the domain is already exist with version range ("
            << map_.at(domain).first << ", " << map_.at(domain).second << "). domain: \"" << domain << "\"" << '\n';
        fail_schema(err.str());
      }
      if (last_release_version_map_.count(domain) != 0) {
        std::stringstream err;
        err << "Trying to add a domain to LastReleaseVersion map, but the domain is already exist with last version: "
            << last_release_version_map_.at(domain) << ", domain: \"" << domain << "\"" << '\n';
        fail_schema(err.str());
      }
      map_[domain] = std::make_pair(min_version, max_version);
      // If a last-release-version is not explicitly specified, use max as
      // last-release-version.
      if (last_release_version == -1) {
        last_release_version = max_version;
      }
      last_release_version_map_[domain] = last_release_version;
    }

    // Updates the version range of an already-added domain; fails if the
    // domain is unknown. Guarded by mutex_.
    ONNX_API void
    UpdateDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) {
      std::lock_guard<std::mutex> lock(mutex_);
      if (map_.count(domain) == 0) {
        std::stringstream err;
        err << "Trying to update a domain in DomainToVersion map, but the domain has not been add. domain: \"" << domain
            << "\"" << '\n';
        fail_schema(err.str());
      }
      if (last_release_version_map_.count(domain) == 0) {
        std::stringstream err;
        err << "Trying to update a domain in LastReleaseVersion map, but the domain has not been add. domain: \""
            << domain << "\"" << '\n';
        fail_schema(err.str());
      }
      map_.at(domain).first = min_version;
      map_.at(domain).second = max_version;
      // Correspond to `AddDomainToVersion`
      if (last_release_version == -1) {
        last_release_version = max_version;
      }
      last_release_version_map_.at(domain) = last_release_version;
    }

    ONNX_API static DomainToVersionRange& Instance();

   private:
    // Key: domain. Value: <lowest version, highest version> pair.
    std::unordered_map<std::string, std::pair<int, int>> map_;

    // Key: domain. Value: most recent release opset version. Note that
    // the highest opset version may be ahead of the most recent release's opset
    // version.
    std::unordered_map<std::string, int> last_release_version_map_;

    std::mutex mutex_;
  };

  class OpSchemaRegisterOnce final {
   public:
    // Export to cpp custom register macro.
    // DO NOT decorate the constructor as "explicit" because that breaks the macro ONNX_OPERATOR_SCHEMA_UNIQ.
    // NOLINTNEXTLINE(google-explicit-constructor)
    OpSchemaRegisterOnce( // NOSONAR
        OpSchema op_schema,
        int opset_version_to_load = 0,
        bool fail_duplicate_schema = true) {
      OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema);
    }
    // Registers the schema, reporting any registration failure on stderr
    // instead of propagating the exception.
    ONNX_API static void
    OpSchemaRegisterNoExcept(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
      ONNX_TRY {
        OpSchemaRegisterImpl(std::move(op_schema), opset_version_to_load, fail_duplicate_schema);
      }
      ONNX_CATCH(const std::exception& e) {
        ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << '\n'; });
      }
    }
    // Finalizes and inserts the schema into the global map, honoring
    // opset_version_to_load (0 = load all versions) and duplicate policy.
    ONNX_API static void
    OpSchemaRegisterImpl(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
      op_schema.Finalize();
      auto& m = GetMapWithoutEnsuringRegistration();
      auto& op_name = op_schema.Name();
      auto& op_domain = op_schema.domain();
      auto& schema_ver_map = m[op_name][op_domain];
      auto ver = op_schema.SinceVersion();
      if (OpSchema::kUninitializedSinceVersion == ver) {
        op_schema.SinceVersion(1);
        ver = op_schema.SinceVersion();
      }

      // Stops because the exact opset_version is registered
      if (schema_ver_map.count(ver)) {
        if (fail_duplicate_schema) {
          const auto& schema = schema_ver_map[ver];
          std::stringstream err;
          err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
              << ") from file " << op_schema.file() << " line " << op_schema.line()
              << ", but it is already registered from file " << schema.file() << " line " << schema.line() << '\n';
          fail_schema(err.str());
        }
        return;
      }

      if (opset_version_to_load != 0) {
        // Stops because the opset_version is higher than opset_version_to_load
        if (ver > opset_version_to_load) {
          return;
        }

        // Stops because a later version is registered within target opset version
        if (!schema_ver_map.empty()) {
          int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load);
          if (max_registered_ver_le_target >= ver) {
            return;
          }
        }
      }

      CheckDomainAndVersionToRegister(op_schema, op_name, op_domain);
      // Key is known to be absent (checked above), so emplace always inserts.
      schema_ver_map.emplace(ver, std::move(op_schema));
    }

   private:
    // Gets the maximum version from given map that is less or equal to target version
    static int GetMaxRegisteredVerWithinTarget(const std::map<OperatorSetVersion, OpSchema>& m, int target_ver) {
      // std::map is sorted on key
      // reverse iterator returns the largest element keyed on the integer version
      for (auto&& it = m.rbegin(); it != m.rend(); ++it) {
        const auto& registered_ver = it->first;
        if (registered_ver <= target_ver) {
          return registered_ver;
        }
      }
      return -1;
    }

    // Fails if the schema's domain is unknown to DomainToVersionRange or its
    // SinceVersion falls outside the domain's inclusive version range.
    ONNX_API static void CheckDomainAndVersionToRegister(
        const OpSchema& op_schema,
        const std::string& op_name,
        const std::string& op_domain) {
      // Take the map by const reference; copying the whole unordered_map on
      // every schema registration is wasteful.
      const auto& ver_range_map = DomainToVersionRange::Instance().Map();
      auto ver_range_it = ver_range_map.find(op_domain);
      auto ver = op_schema.SinceVersion();

      if (ver_range_it == ver_range_map.end()) {
        std::stringstream err;
        err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
            << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its domain is not"
            << " known by the checker." << '\n';

        fail_schema(err.str());
      }
      auto lower_bound_incl = ver_range_it->second.first;
      auto upper_bound_incl = ver_range_it->second.second;
      if (lower_bound_incl > ver || upper_bound_incl < ver) {
        std::stringstream err;
        err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
            << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its version is not "
            << "in the inclusive range [" << lower_bound_incl << ", " << upper_bound_incl
            << "] (usually, this means you "
            << "bumped the operator version but "
            << "forgot to update the version range in DomainToVersionRange "
            << "in onnx/defs/schema.h)." << '\n';
        fail_schema(err.str());
      }
    }
  };

  // Removes the schema registered for (op_type, domain, version); fails if no
  // such schema is registered.
  static void
  OpSchemaDeregister(const std::string& op_type, const int version, const std::string& domain = ONNX_DOMAIN) {
    auto& schema_map = GetMapWithoutEnsuringRegistration();
    if (schema_map.count(op_type) && schema_map[op_type].count(domain) && schema_map[op_type][domain].count(version)) {
      schema_map[op_type][domain].erase(version);
    } else {
      std::stringstream err;
      err << "Attempting to deregister an unregistered schema with name: " << op_type << " domain: " << domain
          << " version: " << version << '\n';
      fail_schema(err.str());
    }
  }

  // Deregister all ONNX opset schemas from domain
  // Domain with default value ONNX_DOMAIN means ONNX.
  static void OpSchemaDeregisterAll(const std::string& domain = ONNX_DOMAIN) {
    auto& schema_map = GetMapWithoutEnsuringRegistration();
    // schema_map stores operator schemas in the format of
    // <OpName, <Domain, <OperatorSetVersion, OpSchema>>>
    for (auto&& schema_map_pair : schema_map) {
      auto& domain_map = schema_map_pair.second;
      if (domain_map.count(domain)) {
        auto& opset_version_schema_map = domain_map[domain];
        // Invalidates ver-schema pairs and frees memory, leaving m[op_name][op_domain] empty
        opset_version_schema_map.clear();
        domain_map.erase(domain);
      }
    }
  }

  // Return the latest schema for an operator in specified domain.
  // Domain with default value ONNX_DOMAIN means ONNX.
  static const OpSchema* Schema(const std::string& key, const std::string& domain = ONNX_DOMAIN) {
    auto& m = map();
    if (m.count(key) && m[key].count(domain)) {
      const auto& schema_ver_map = m[key][domain];
      if (!schema_ver_map.empty()) {
        // Versions are sorted ascending, so rbegin() is the newest.
        return &schema_ver_map.rbegin()->second;
      }
    }
    return nullptr;
  }

  // Return the schema with biggest version, which is not greater than specified
  // <maxInclusiveVersion> in specified domain. Domain with default value
  // ONNX_DOMAIN means ONNX.
  ONNX_API static const OpSchema*
  Schema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) {
    auto& m = map();
    if (m.count(key) && m[key].count(domain)) {
      const auto& schema_ver_map = m[key][domain];
      if (!schema_ver_map.empty()) {
        // Reuse the reference obtained above instead of repeating the
        // m[key][domain] lookups.
        auto pos = schema_ver_map.lower_bound(maxInclusiveVersion);
        if (schema_ver_map.begin() == pos && pos->first > maxInclusiveVersion) {
          // All versions are greater than specified version.
          return nullptr;
        }
        if (schema_ver_map.end() == pos || pos->first > maxInclusiveVersion) {
          // All versions are less than specified version, or,
          // The <pos> version is greater than specified version.
          pos--;
        }

        // Schema with exact version as specified one exists.
        return &(pos->second);
      }
    }
    return nullptr;
  }

  ONNX_API static OpSchemaRegistry* Instance();

  // NOLINTNEXTLINE(google-default-arguments)
  ONNX_API const OpSchema* GetSchema(
      const std::string& key,
      const int maxInclusiveVersion,
      const std::string& domain = ONNX_DOMAIN) const override {
    return Schema(key, maxInclusiveVersion, domain);
  }
  ONNX_API static void SetLoadedSchemaVersion(int target_version) {
    loaded_schema_version = target_version;
  }
  ONNX_API static int GetLoadedSchemaVersion() {
    return loaded_schema_version;
  }

 private:
  // OpSchemaRegistry should not need to be instantiated except statically
  // within this class
  OpSchemaRegistry() = default;

  /**
   * @brief Returns the underlying string to OpSchema map.
   *
   * You should not manually manipulate the map object returned. Instead, use
   * the macros defined such as ONNX_OPERATOR_SET_SCHEMA to register your
   * operator schema.
   *
   * We wrap it inside a function to avoid the static initialization order
   * fiasco.
   *
   * With the change in function visibility, the
   * GetMapWithoutEnsuringRegistration() and map() methods cannot be used
   * to access the schema map directly from outside the OpSchemaRegistry class.
   * Hence the ONNX_API macro is used to ensure that the methods are
   * accessible from other translation units providing backward compatibility.
   */
  ONNX_API static OpName_Domain_Version_Schema_Map& GetMapWithoutEnsuringRegistration();
  ONNX_API static OpName_Domain_Version_Schema_Map& map();
  static int loaded_schema_version;

 public:
  // Returns every registered schema, including all historical versions.
  static std::vector<OpSchema> get_all_schemas_with_history() {
    std::vector<OpSchema> r;
    for (auto& x : map()) {
      for (auto& y : x.second) {
        for (auto& z : y.second) {
          r.emplace_back(z.second);
        }
      }
    }
    return r;
  }

  // Returns only the newest registered version of each (op, domain) pair.
  static std::vector<OpSchema> get_all_schemas() {
    std::vector<OpSchema> r;
    for (auto& x : map()) {
      for (auto& y : x.second) {
        auto& version2schema = y.second;
        if (!version2schema.empty()) {
          r.emplace_back(version2schema.rbegin()->second);
        }
      }
    }
    return r;
  }
};

// Registers `schema` in the global OpSchemaRegistry.
// opset_version_to_load = 0 registers all versions; see
// OpSchemaRegisterOnce::OpSchemaRegisterImpl for the filtering rules.
// NOTE(review): fail_with_exception presumably propagates registration errors
// as exceptions instead of logging them — confirm against the .cc definition.
ONNX_API void RegisterSchema(
    const OpSchema& schema,
    int opset_version_to_load = 0,
    bool fail_duplicate_schema = true,
    bool fail_with_exception = false);
// Rvalue overload; avoids copying the schema into the registry.
ONNX_API void RegisterSchema(
    OpSchema&& schema,
    int opset_version_to_load = 0,
    bool fail_duplicate_schema = true,
    bool fail_with_exception = false);
// Removes the schema registered for (op_type, domain, version);
// see OpSchemaRegistry::OpSchemaDeregister for the failure behavior.
ONNX_API void DeregisterSchema(const std::string& op_type, int version, const std::string& domain);

// Registers the schemas provided by opset collection T, at most up to
// opset_version_to_load. The default opset_version_to_load = 0 registers
// every available version.
template <class T>
ONNX_API void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
  // Forward each schema produced by the opset to the global registry.
  auto register_one = [opset_version_to_load, fail_duplicate_schema](OpSchema&& schema) {
    RegisterSchema(std::move(schema), opset_version_to_load, fail_duplicate_schema);
  };
  T::ForEachSchema(register_one);
}

// Forward declaration for the non-specialized GetOpSchema method.  This
// enforces a consistent signature on functions that query individual schema,
// which are defined as specializations of this function.
template <typename T>
ONNX_API OpSchema GetOpSchema();

// Per-domain convenience wrappers over ONNX_OPERATOR_SET_SCHEMA_EX, all with
// dbg_included_in_static_opset = true. ONNX (default) domain:
#define ONNX_OPERATOR_SET_SCHEMA(name, ver, impl) ONNX_OPERATOR_SET_SCHEMA_EX(name, Onnx, ONNX_DOMAIN, ver, true, impl)

// AI_ONNX_ML_DOMAIN:
#define ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxML, AI_ONNX_ML_DOMAIN, ver, true, impl)

// AI_ONNX_TRAINING_DOMAIN:
#define ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxTraining, AI_ONNX_TRAINING_DOMAIN, ver, true, impl)

// AI_ONNX_PREVIEW_TRAINING_DOMAIN:
#define ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxPreview, AI_ONNX_PREVIEW_TRAINING_DOMAIN, ver, true, impl)

#ifdef NDEBUG
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() 0
#define ONNX_OPERATOR_SET_SCHEMA_DEBUG_VARIABLE(name, domain, ver, dbg_included_in_static_opset) \
  static size_t dbg_count_check_##name##_##domain##_ver##ver [[maybe_unused]] = 0
#else
// Debug-build counter of operator schemas included in static opsets; used by
// ONNX_DBG_INCREMENT_COUNT_IN_OPSETS / ONNX_DBG_GET_COUNT_IN_OPSETS below.
class DbgOperatorSetTracker {
 public:
  ONNX_API static DbgOperatorSetTracker& Instance();

  // Bumps the tracked count and returns the updated total.
  ONNX_API size_t IncrementCount() {
    count_ += 1;
    return count_;
  }

  // Returns the number of increments observed so far.
  ONNX_API size_t GetCount() const {
    return count_;
  }

 private:
  size_t count_ = 0;
};
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().IncrementCount()
#define ONNX_OPERATOR_SET_SCHEMA_DEBUG_VARIABLE(name, domain, ver, dbg_included_in_static_opset) \
  static size_t dbg_count_check_##name##_##domain##_ver##ver [[maybe_unused]] =                  \
      (dbg_included_in_static_opset) ? ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() : 0;
#endif

// Defines specialization of GetOpSchema for a class whose name is determined
// based on a convention using name, domain, and version.  Operator schema are
// normally included in operator sets and registered in OpSchemaRegistry::map().
// In this case, callers should set dbg_included_in_static_opset to true.  This
// assists with runtime validation in DEBUG builds ensuring the intended set
// of operator schema is registered.

// Note: ONNX_OPERATOR_SET_SCHEMA_DEBUG_VARIABLE takes its arguments in the
// order (name, domain, ver, dbg_included_in_static_opset); keep the invocation
// below in sync with the macro's parameter list so the generated debug
// variable follows the dbg_count_check_<name>_<domain>_ver<ver> convention.
#define ONNX_OPERATOR_SET_SCHEMA_EX(name, domain, domain_str, ver, dbg_included_in_static_opset, impl)  \
  class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name);                                         \
  template <>                                                                                           \
  ONNX_API OpSchema GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)>() {             \
    return impl.SetName(#name).SetDomain(domain_str).SinceVersion(ver).SetLocation(__FILE__, __LINE__); \
  }                                                                                                     \
  ONNX_OPERATOR_SET_SCHEMA_DEBUG_VARIABLE(name, domain, ver, dbg_included_in_static_opset)
#ifndef NDEBUG
#define ONNX_DBG_GET_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().GetCount()

#endif

// Naming convention for operator schema classes
#define ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name) name##_##domain##_ver##ver

// Naming convention for preview operator schema classes
#define ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(ver, name) \
  ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxPreview, ver, name)

// Helper function
size_t ReplaceAll(std::string& s, const char* from, const char* to);

#ifdef __GNUC__
#define ONNX_UNUSED __attribute__((__unused__))
#else
#define ONNX_UNUSED
#endif

// Legacy macros to register schema at static initialization
// __COUNTER__ is routed through a helper macro so it is expanded to its
// numeric value before token pasting, giving each registration a unique
// static variable name.
#define ONNX_OPERATOR_SCHEMA(name) ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)                                                                     \
  static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce op_schema_register_once##name##Counter ONNX_UNUSED = \
      ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__)

// Returns the boilerplate documentation appended to schemas that declare
// optional inputs/outputs.
ONNX_API inline std::string GenerateOptionalArgumentsDoc() {
  static const char* const kOptionalArgsDoc =
      "This operator has **optional** inputs/outputs. "
      "See [the doc](IR.md) for more details about the representation of "
      "optional arguments. An empty string may be used in the place of "
      "an actual argument's name to indicate a missing argument. "
      "Trailing optional arguments (those not followed by an argument "
      "that is present) may also be simply omitted.\n";
  return kOptionalArgsDoc;
}

// Returns the shared documentation snippet for operators that support
// multidirectional (Numpy-style) broadcasting.
ONNX_API inline std::string GenerateBroadcastingDocMul() {
  static const char* const kMulBroadcastDoc =
      "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**;"
      " for more details please check [the doc](Broadcasting.md).";
  return kMulBroadcastDoc;
}

// Returns the shared documentation snippet for operators where `from` must be
// unidirectionally broadcastable to `to`.
ONNX_API inline std::string GenerateBroadcastingDocUni(const char* from, const char* to) {
  std::string doc{"This operator supports **unidirectional broadcasting** ("};
  doc.append(from);
  doc.append(" should be unidirectional broadcastable to ");
  doc.append(to);
  doc.append(");"
             " for more details please check [the doc](Broadcasting.md).");
  return doc;
}

/*
 * Macros for setting operator documentation
 * Use this macro for simple SetDoc() calls that generate documentation
 * directly. This is the macro to use in almost all cases.
 * Sample usage guidelines:
 * const char* doc_str = "foo";
 * SetDoc(GET_OP_DOC_STR(doc_str))
 *
 * SetDoc(GET_OP_DOC_STR(
            std::string(BitShift_ver11_doc) + GenerateBroadcastingDocMul()))
 */
#ifndef __ONNX_NO_DOC_STRINGS
#define GET_OP_DOC_STR(doc_str) (doc_str)
#else
#define GET_OP_DOC_STR(doc_str) ("")
#endif

/*
 * Use this macro when the documentation needs to be populated in some
 * complicated way like string substitutions, etc before calling SetDoc.
 * Sample usage guidelines:
    std::string doc;
    POPULATE_OP_DOC_STR(
        doc = R"DOC(
Returns the tensor resulted from performing the `{name}` logical operation
elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting
support).

{broadcast_doc}
)DOC";
        ReplaceAll(doc, "{name}", name);
        ReplaceAll(
            doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
    schema.SetDoc(doc);
 *
 */
#ifndef __ONNX_NO_DOC_STRINGS
#define POPULATE_OP_DOC_STR(DocPopulatorCode) \
  do {                                        \
    DocPopulatorCode                          \
  } while (0)
#else
#define POPULATE_OP_DOC_STR(DocPopulatorCode)
#endif

} // namespace ONNX_NAMESPACE
