/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "onnx/defs/schema.h"

#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "onnx/checker.h"
#include "onnx/defs/operator_sets.h"
#include "onnx/defs/operator_sets_preview.h"
#include "onnx/defs/operator_sets_training.h"

#ifdef ONNX_ML
#include "onnx/defs/operator_sets_ml.h"
#endif

#ifndef NDEBUG
#include "onnx/common/assertions.h"
#endif
#include "onnx/defs/parser.h"

namespace ONNX_NAMESPACE {
// Records which ONNX schemas have been loaded into the registry:
//   -1  no ONNX schema has been loaded yet;
//    0  schemas for all opset versions have been loaded;
//   >0  the ONNX schemas for that specific opset version have been loaded.
int OpSchemaRegistry::loaded_schema_version{-1};

// By default if opset_version_to_load=0, it registers all opset schema for all opset versions
// Otherwise, it only registers the latest schema according to opset_version_to_load
void RegisterSchema(
    const OpSchema& schema,
    int opset_version_to_load,
    bool fail_duplicate_schema,
    bool fail_with_exception) {
  RegisterSchema(OpSchema(schema), opset_version_to_load, fail_duplicate_schema, fail_with_exception);
}
// Takes ownership of `schema` and hands it to the registry.
// With fail_with_exception == false the no-throw registration path
// (OpSchemaRegisterNoExcept) is used; otherwise OpSchemaRegisterImpl is called.
void RegisterSchema(
    OpSchema&& schema,
    int opset_version_to_load,
    bool fail_duplicate_schema,
    bool fail_with_exception) {
  if (!fail_with_exception) {
    OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterNoExcept(
        std::move(schema), opset_version_to_load, fail_duplicate_schema);
    return;
  }
  OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl(
      std::move(schema), opset_version_to_load, fail_duplicate_schema);
}

// Removes the schema identified by (op_type, version, domain) from the registry.
// The triple must match a registered schema exactly; otherwise a SchemaError is raised.
void DeregisterSchema(const std::string& op_type, int version, const std::string& domain) {
  OpSchemaRegistry::OpSchemaDeregister(op_type, version, domain);
}

#ifndef NDEBUG
// Debug-build-only accessor for the process-wide DbgOperatorSetTracker (function-local static).
DbgOperatorSetTracker& DbgOperatorSetTracker::Instance() {
  static DbgOperatorSetTracker tracker;
  return tracker;
}
#endif

const std::string& OpSchema::FormalParameter::GetName() const {
  return name_;
}

const DataTypeSet& OpSchema::FormalParameter::GetTypes() const {
  return type_set_;
}

DataTypeSet& OpSchema::FormalParameter::MutableTypes() {
  return type_set_;
}

const std::string& OpSchema::FormalParameter::GetTypeStr() const {
  return type_str_;
}

const std::string& OpSchema::FormalParameter::GetDescription() const {
  return description_;
}

OpSchema::FormalParameterOption OpSchema::FormalParameter::GetOption() const {
  return param_option_;
}

bool OpSchema::FormalParameter::GetIsHomogeneous() const {
  return is_homogeneous_;
}

int OpSchema::FormalParameter::GetMinArity() const {
  return min_arity_;
}

OpSchema::DifferentiationCategory OpSchema::FormalParameter::GetDifferentiationCategory() const {
  return differentiation_category_;
}

// Returns a pointer to the process-wide registry, created on first use (function-local static).
OpSchemaRegistry* OpSchemaRegistry::Instance() {
  static OpSchemaRegistry registry;
  OpSchemaRegistry* const registry_ptr = &registry;
  return registry_ptr;
}

// Type-checks the inputs/outputs visible through `ctx` against this schema, and fills in output
// types that are still unset when they can be determined.
//
// The input/output counts are validated first (VerifyInputNum / VerifyOutputNum). Each actual
// input/output is then matched to its formal parameter. Actuals beyond the declared parameters
// reuse the last one, which covers variadic tails. For each actual:
//  - its type must be one of the parameter's allowed types (when that set is non-empty);
//  - for homogeneous parameters, every actual bound to the same type string (e.g. "T") must have
//    the same concrete type.
// Inputs with no type (null or VALUE_NOT_SET) are skipped. An unset output type is inferred when
// its parameter allows exactly one type, or when its type string was already bound by a
// homogeneous parameter. Otherwise the output is left unset and skipped.
// Violations are reported through fail_check.
void OpSchema::CheckInputOutputType(struct InferenceContext& ctx) const {
  // Concrete type string bound to each type-constraint string by homogeneous parameters.
  std::unordered_map<std::string, std::string> type_constraints;
  // Check the number of inputs / outputs.
  VerifyInputNum(ctx.getNumInputs());
  VerifyOutputNum(ctx.getNumOutputs());

  // check all input types
  for (size_t in_idx = 0; in_idx < ctx.getNumInputs(); ++in_idx) {
    // If the last input is Variadic by definition, checker still needs to check the rest of actual input's type
    const auto& param = (in_idx < inputs_.size()) ? inputs_[in_idx] : inputs_.back();
    const auto& type_str = param.GetTypeStr();
    const auto* param_type = ctx.getInputType(in_idx);
    if (nullptr == param_type || param_type->value_case() == TypeProto::VALUE_NOT_SET) {
      continue;
    }
    const auto actual_type = Utils::DataTypeUtils::ToType(*param_type);
    const auto& all_types = param.GetTypes();
    if (!all_types.empty() && all_types.find(actual_type) == all_types.end()) {
      fail_check(param.GetName(), " typestr: ", type_str, ", has unsupported type: ", *actual_type);
    }
    if (param.GetIsHomogeneous()) {
      // emplace leaves an existing binding untouched; in that case the types must agree.
      const auto binding = type_constraints.emplace(type_str, *actual_type);
      if (!binding.second && binding.first->second != *actual_type) {
        fail_check(param.GetName(), " has inconsistent type ", *actual_type);
      }
    }
  } // for inputs

  // check all output types
  for (size_t out_idx = 0; out_idx < ctx.getNumOutputs(); ++out_idx) {
    // If the last output is Variadic by definition, checker still needs to check the rest of actual output's type
    const auto& param = (out_idx < outputs_.size()) ? outputs_[out_idx] : outputs_.back();
    const auto& type_str = param.GetTypeStr();
    auto* param_type = ctx.getOutputType(out_idx);
    if (nullptr == param_type) {
      // No type slot to check or infer into; treated like a null input type above.
      continue;
    }
    const auto& all_types = param.GetTypes();
    // infer type if necessary
    if (param_type->value_case() == TypeProto::VALUE_NOT_SET) {
      if (all_types.size() == 1) {
        *param_type = Utils::DataTypeUtils::ToTypeProto(*all_types.begin());
      } else {
        const auto bound = type_constraints.find(type_str);
        if (bound == type_constraints.end()) {
          continue; // not inferable: leave the output type unset
        }
        *param_type = Utils::DataTypeUtils::ToTypeProto(Utils::DataTypeUtils::ToType(bound->second));
      }
    }
    const auto actual_type = Utils::DataTypeUtils::ToType(*param_type);
    if (!all_types.empty() && all_types.find(actual_type) == all_types.end()) {
      fail_check(param.GetName(), " has unsupported type ", *actual_type);
    }
    if (param.GetIsHomogeneous()) {
      const auto binding = type_constraints.emplace(type_str, *actual_type);
      if (!binding.second && binding.first->second != *actual_type) {
        fail_check(param.GetName(), " has inconsistent type ", *actual_type);
      }
    }
  } // for outputs
}

// Validates `node` against this schema and reports the first violation through fail_check.
//
// Checks, in order:
//  - the operator is not deprecated;
//  - input/output counts pass VerifyInputNum / VerifyOutputNum;
//  - inputs/outputs beyond the declared formal parameters are only accepted when the last formal
//    parameter is Variadic;
//  - an input/output whose formal parameter is Single is not an empty string;
//  - each attribute appears at most once and is declared by the schema. Undeclared attributes are
//    skipped when unchecked attributes are allowed or the name starts with "__".
//  - each attribute has the declared type and, unless it is a reference attribute (non-empty
//    ref_attr_name), carries the payload field that type requires;
//  - every required attribute is present.
void OpSchema::Verify(const NodeProto& node) const {
  if (deprecated_) {
    fail_check("Operator '", name_, "' has been deprecated since version ", since_version_);
  }

  VerifyInputNum(node.input_size(), node.name());
  VerifyOutputNum(node.output_size(), node.name());

  // Check the values of inputs / outputs
  for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) {
    if (in_idx >= static_cast<int>(inputs_.size())) {
      if (!inputs_.empty() && Variadic == inputs_.back().GetOption()) {
        // The last input formal parameter should be variadic.
        break;
      } else {
        fail_check(
            "Node (",
            node.name(),
            ") has more inputs (",
            node.input_size(),
            ") than declared (",
            inputs_.size(),
            ") in op definition.");
      }
    }
    if (node.input(in_idx).empty() && (Single == inputs_[in_idx].GetOption())) {
      fail_check("Node (", node.name(), ")'s input ", in_idx, " is marked single but has an empty string in the graph");
    }
  }

  for (int out_idx = 0; out_idx < node.output_size(); ++out_idx) {
    if (out_idx >= static_cast<int>(outputs_.size())) {
      if (!outputs_.empty() && Variadic == outputs_.back().GetOption()) {
        // The last output formal parameter should be variadic.
        break;
      } else {
        fail_check(
            "Node (",
            node.name(),
            ") has more outputs (",
            node.output_size(),
            ") than declared (",
            outputs_.size(),
            ") in op definition.");
      }
    }

    if (node.output(out_idx).empty() && (Single == outputs_[out_idx].GetOption())) {
      fail_check(
          "Node (", node.name(), ")'s output ", out_idx, " is marked single but has an empty string in the graph");
    }
  }

  // An internal symbol is defined as starting with two underscores. Attributes
  // with names meeting this condition are considered implementation details
  // and should be ignored for the purpose of schema checking.
  auto isInternalSymbol = [](const std::string& sym) -> bool {
    return sym.length() >= 2 && sym[0] == '_' && sym[1] == '_';
  };

  // Check attributes
  std::unordered_set<std::string> seen_attr_names{};
  for (const auto& attr_proto : node.attribute()) {
    const auto& name = attr_proto.name();

    if (!seen_attr_names.insert(name).second) {
      fail_check("Attribute '", name, "' appeared multiple times.");
    }

    const auto& search = attributes_.find(name);
    AttributeProto::AttributeType expected_type{};
    if (search != attributes_.end()) {
      expected_type = search->second.type;
    } else if (allows_unchecked_attributes_ || isInternalSymbol(name)) {
      continue;
    } else {
      fail_check("Unrecognized attribute: ", name, " for operator ", node.op_type());
    }

    // Type would be UNDEFINED if not set
    if (attr_proto.type() != expected_type) {
      fail_check(
          "Mismatched attribute type in '",
          node.name() + " : " + name,
          "'. Expected: '",
          AttributeProto_AttributeType_Name(expected_type),
          "', actual: '",
          AttributeProto_AttributeType_Name(attr_proto.type()),
          "'");
    }

    // ref_attr_name is only valid when non-empty
    // we simply read default value if not present
    if (!attr_proto.ref_attr_name().empty()) {
      continue;
    }

    switch (expected_type) {
      // if attr_proto().type() != UNDEFINED
      // we consider primitive types to be set even
      // if proto3 did not output default values into the stream
      // in which case we will read the default
      case AttributeProto::FLOAT:
      case AttributeProto::INT:
      case AttributeProto::STRING:
        break;
      case AttributeProto::TENSOR:
        if (!attr_proto.has_t()) {
          fail_check("Attribute '", name, "' is expected to have field 't'");
        }
        break;
      case AttributeProto::SPARSE_TENSOR:
        if (!attr_proto.has_sparse_tensor()) {
          fail_check("Attribute '", name, "' is expected to have field 'sparse_tensor'");
        }
        break;
      case AttributeProto::GRAPH:
        if (!attr_proto.has_g()) {
          fail_check("Attribute '", name, "' is expected to have field 'g'");
        }
        break;
      case AttributeProto::TYPE_PROTO:
        if (!attr_proto.has_tp()) {
          fail_check("Attribute '", name, "' is expected to have field 'type_proto'");
        }
        break;
      case AttributeProto::INTS:
      case AttributeProto::FLOATS:
      case AttributeProto::TENSORS:
      case AttributeProto::STRINGS:
      case AttributeProto::SPARSE_TENSORS:
      case AttributeProto::GRAPHS:
      case AttributeProto::TYPE_PROTOS:
        // No check ... whether an empty list is a valid value for the attribute
        // is op specific.
        break;
      default:
        // Closing quote added so the attribute name is delimited like in the other messages.
        fail_check("Attribute '", name, "' has unknown expected type");
    }
  }
  for (const auto& pair : attributes_) {
    const auto& attr = pair.second;
    if (!attr.required) {
      continue;
    }
    if (!seen_attr_names.count(attr.name)) {
      fail_check("Required attribute '", attr.name, "' is missing.");
    }
  }

  // Phew. All verifications passed.
}

// Builds the common prefix for count-verification errors, e.g.
// "Node(my_node) with schema(domain::Op:13)". The "(name)" part is omitted when node_name is empty.
std::string OpSchema::VerifyFailPrefix(std::string_view node_name) const {
  std::string prefix("Node");
  if (!node_name.empty()) {
    prefix += '(';
    prefix.append(node_name.data(), node_name.size());
    prefix += ')';
  }
  prefix += " with schema(";
  prefix += domain();
  prefix += "::";
  prefix += Name();
  prefix += ':';
  prefix += std::to_string(since_version());
  prefix += ')';
  return prefix;
}

// Fails via fail_check unless input_num lies in [min_input_, max_input_] and is accepted by
// num_inputs_allowed_. node_name, when non-empty, appears in the error prefix.
void OpSchema::VerifyInputNum(int input_num, std::string_view node_name) const {
  const bool within_bounds = min_input_ <= input_num && input_num <= max_input_;
  if (!within_bounds) {
    fail_check(
        VerifyFailPrefix(node_name),
        " has input size ",
        input_num,
        " not in range [min=",
        min_input_,
        ", max=",
        max_input_,
        "].");
  }

  if (num_inputs_allowed_(input_num)) {
    return;
  }
  fail_check(VerifyFailPrefix(node_name), " has input size ", input_num, " not in allowed input sizes.");
}

// Fails via fail_check unless output_num lies in [min_output_, max_output_] and is accepted by
// num_outputs_allowed_. node_name, when non-empty, appears in the error prefix.
void OpSchema::VerifyOutputNum(int output_num, std::string_view node_name) const {
  const bool within_bounds = min_output_ <= output_num && output_num <= max_output_;
  if (!within_bounds) {
    fail_check(
        VerifyFailPrefix(node_name),
        " has output size ",
        output_num,
        " not in range [min=",
        min_output_,
        ", max=",
        max_output_,
        "].");
  }

  if (num_outputs_allowed_(output_num)) {
    return;
  }
  fail_check(VerifyFailPrefix(node_name), " has output size ", output_num, " not in allowed output sizes.");
}

// Sets the opset version this schema is introduced in. It also re-keys any function body or
// context-dependent function-body builder that was registered under kUninitializedSinceVersion.
//
// When defining an op, SinceVersion is called after FunctionBody() /
// SetContextDependentFunctionBodyBuilder(). Those use kUninitializedSinceVersion as the default
// opset_version, meaning "the function for the op's own since_version_". So those entries must be
// moved under since_version_ once it is known. A re-keyed function body also has its opset_import
// for this schema's domain updated to since_version_.
//
// Function bodies for other opset versions exist for models whose opset is newer than the op's
// since_version_, where ops used by the default body have changed. Example: LayerNormalization
// (since opset 17) is a function op whose body uses ReduceMean. A model of opset 18 may contain
// LayerNormalization, but ReduceMean changed in opset 18. A runtime that inlines LayerNormalization
// in that model therefore needs a second function body.
OpSchema& OpSchema::SinceVersion(OperatorSetVersion v) {
  since_version_ = v;

  if (since_version_ == OpSchema::kUninitializedSinceVersion) {
    // Re-keying kUninitializedSinceVersion onto itself and then erasing the old entry would
    // silently drop the registered function body/builder; leave them in place instead.
    return *this;
  }

  auto builder_it = opset_version_to_function_builder_.find(OpSchema::kUninitializedSinceVersion);
  if (builder_it != opset_version_to_function_builder_.end()) {
    opset_version_to_function_builder_[since_version_] = std::move(builder_it->second);
    opset_version_to_function_builder_.erase(builder_it);
  }

  auto body_it = opset_version_to_function_body_.find(OpSchema::kUninitializedSinceVersion);
  if (body_it != opset_version_to_function_body_.end()) {
    // The map is ordered (std::map semantics), so inserting the new key leaves body_it valid.
    auto& body = opset_version_to_function_body_[since_version_];
    body = std::move(body_it->second);
    UpdateFunctionProtoOpsetImportVersion(*body, since_version_);
    opset_version_to_function_body_.erase(body_it);
  }

  return *this;
}

// Marks the operator as deprecated; Verify() rejects nodes that use a deprecated schema.
OpSchema& OpSchema::Deprecate() {
  this->deprecated_ = true;
  return *this;
}

// Restricts accepted input counts to the given set. VerifyInputNum applies this predicate in
// addition to the [min, max] range check.
OpSchema& OpSchema::NumInputs(std::unordered_set<int> allowed_input_nums) {
  auto predicate = [allowed = std::move(allowed_input_nums)](int count) -> bool {
    return allowed.find(count) != allowed.end();
  };
  num_inputs_allowed_ = std::move(predicate);
  return *this;
}

// Restricts accepted output counts to the given set. VerifyOutputNum applies this predicate in
// addition to the [min, max] range check.
OpSchema& OpSchema::NumOutputs(std::unordered_set<int> allowed_output_nums) {
  auto predicate = [allowed = std::move(allowed_output_nums)](int count) -> bool {
    return allowed.find(count) != allowed.end();
  };
  num_outputs_allowed_ = std::move(predicate);
  return *this;
}

// Installs the type-and-shape inference function for this operator.
OpSchema& OpSchema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction) {
  this->tensor_inference_function_ = std::move(inferenceFunction);
  return *this;
}

// Installs the partial data-propagation function for this operator.
OpSchema& OpSchema::PartialDataPropagationFunction(DataPropagationFunction dataPropagationFunction) {
  this->data_propagation_function_ = std::move(dataPropagationFunction);
  return *this;
}

// Records the support level (SupportType) of this operator.
OpSchema& OpSchema::SetSupportLevel(SupportType support) {
  this->support_ = support;
  return *this;
}

// Sets the operator name of this schema (takes ownership of the string).
OpSchema& OpSchema::SetName(std::string name) {
  this->name_ = std::move(name);
  return *this;
}

// C-string convenience overload; forwards to SetName(std::string).
OpSchema& OpSchema::SetName(const char* name) {
  std::string owned_name(name);
  return SetName(std::move(owned_name));
}

// Records the source file and line where this schema is defined.
OpSchema& OpSchema::SetLocation(std::string file, int line) {
  this->file_ = std::move(file);
  this->line_ = line;
  return *this;
}

// C-string convenience overload; forwards to SetLocation(std::string, int).
OpSchema& OpSchema::SetLocation(const char* file, int line) {
  std::string owned_file(file);
  return SetLocation(std::move(owned_file), line);
}

// Sets the operator domain of this schema (takes ownership of the string).
OpSchema& OpSchema::SetDomain(std::string domain) {
  this->domain_ = std::move(domain);
  return *this;
}

// C-string convenience overload; forwards to SetDomain(std::string).
OpSchema& OpSchema::SetDomain(const char* domain) {
  std::string owned_domain(domain);
  return SetDomain(std::move(owned_domain));
}

// Adds `attr` keyed by its name. If an attribute with that name already exists, the existing entry
// is kept (map emplace does not overwrite).
OpSchema& OpSchema::Attr(Attribute attr) {
  std::string key(attr.name); // copied because attr itself is moved into the map below
  attributes_.emplace(std::move(key), std::move(attr));
  return *this;
}

// Declares an attribute without a default value; `required` controls whether Verify() demands it.
OpSchema& OpSchema::Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required) {
  Attribute declared{std::move(name), std::move(description), type, required};
  return Attr(std::move(declared));
}

// C-string convenience overload of Attr(std::string, std::string, AttributeType, bool).
OpSchema& OpSchema::Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required) {
  std::string owned_name(name);
  std::string owned_description(description);
  return Attr(std::move(owned_name), std::move(owned_description), type, required);
}

// Defines the OpSchema::Attr overloads (std::string and const char* names) that declare an
// attribute with a scalar default value of C++ type `type`. The default is stored through
// AttributeProto::set_<field>. The caller-supplied attr_type must equal `attrtype`; otherwise
// fail_schema is raised.
// NOTE: no `//` comments inside the macro body — a trailing `\` would splice the next line into them.
#define ATTR_SETTER_WITH_SINGLE_VALUE(type, field, attrtype)                                                           \
  OpSchema& OpSchema::Attr(                                                                                            \
      std::string name, std::string description, AttributeProto::AttributeType attr_type, const type& default_value) { \
    if (attrtype != attr_type) {                                                                                       \
      fail_schema("Attribute specification type mismatch.");                                                           \
    }                                                                                                                  \
    AttributeProto a;                                                                                                  \
    a.set_name(name);                                                                                                  \
    a.set_##field(default_value);                                                                                      \
    a.set_type(attr_type);                                                                                             \
    Attr(Attribute(std::move(name), std::move(description), std::move(a)));                                            \
    return *this;                                                                                                      \
  }                                                                                                                    \
  OpSchema& OpSchema::Attr(                                                                                            \
      const char* name, const char* description, AttributeProto::AttributeType attr_type, const type& default_value) { \
    return Attr(std::string(name), std::string(description), attr_type, default_value);                                \
  }

// Defines the OpSchema::Attr overload that declares an attribute whose default is a list of scalar
// `type` values. Each element is appended through AttributeProto::add_<field>. The caller-supplied
// attr_type must equal `attrtype`; otherwise fail_schema is raised.
#define ATTR_SETTER_WITH_LIST_VALUE(type, field, attrtype)                  \
  OpSchema& OpSchema::Attr(                                                 \
      std::string name,                                                     \
      std::string description,                                              \
      AttributeProto::AttributeType attr_type,                              \
      const std::vector<type>& default_value) {                             \
    if (attrtype != attr_type) {                                            \
      fail_schema("Attribute specification type mismatch.");                \
    }                                                                       \
    AttributeProto a;                                                       \
    a.set_name(name);                                                       \
    a.set_type(attr_type);                                                  \
    for (const auto& v : default_value) {                                   \
      a.add_##field(v);                                                     \
    }                                                                       \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this;                                                           \
  }

// Defines the OpSchema::Attr overload that declares an attribute whose default is a single
// message-typed value (e.g. TensorProto, GraphProto, TypeProto). The value is copied into
// AttributeProto::mutable_<field>(). The caller-supplied attr_type must equal `attrtype`;
// otherwise fail_schema is raised.
// The filled AttributeProto is moved (not copied) into the Attribute, as in the sibling macros.
// This avoids duplicating a potentially large tensor or graph.
// NOTE: no `//` comments inside the macro body — a trailing `\` would splice the next line into them.
#define ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(type, field, attrtype)                                                    \
  OpSchema& OpSchema::Attr(                                                                                            \
      std::string name, std::string description, AttributeProto::AttributeType attr_type, const type& default_value) { \
    if (attrtype != attr_type) {                                                                                       \
      fail_schema("Attribute specification type mismatch.");                                                           \
    }                                                                                                                  \
    AttributeProto a;                                                                                                  \
    a.set_name(name);                                                                                                  \
    *(a.mutable_##field()) = default_value;                                                                            \
    a.set_type(attr_type);                                                                                             \
    Attr(Attribute(std::move(name), std::move(description), std::move(a)));                                            \
    return *this;                                                                                                      \
  }

// Defines the OpSchema::Attr overload that declares an attribute whose default is a list of
// message- or string-typed values. Each element is copied into a new slot from
// AttributeProto::add_<field>(). The caller-supplied attr_type must equal `attrtype`; otherwise
// fail_schema is raised.
#define ATTR_SETTER_WITH_LIST_COMPLEXVALUE(type, field, attrtype)           \
  OpSchema& OpSchema::Attr(                                                 \
      std::string name,                                                     \
      std::string description,                                              \
      AttributeProto::AttributeType attr_type,                              \
      const std::vector<type>& default_value) {                             \
    if (attrtype != attr_type) {                                            \
      fail_schema("Attribute specification type mismatch.");                \
    }                                                                       \
    AttributeProto a;                                                       \
    a.set_name(name);                                                       \
    a.set_type(attr_type);                                                  \
    for (const auto& v : default_value) {                                   \
      *(a.add_##field()) = v;                                               \
    }                                                                       \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this;                                                           \
  }

// Instantiate the Attr(...) overloads for every supported default-value type. Each line pairs a
// C++ type with its AttributeProto field and the AttributeType the caller must pass.
ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i, AttributeProto::INT)
ATTR_SETTER_WITH_SINGLE_VALUE(float, f, AttributeProto::FLOAT)
ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s, AttributeProto::STRING)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(TensorProto, t, AttributeProto::TENSOR)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(GraphProto, g, AttributeProto::GRAPH)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(TypeProto, tp, AttributeProto::TYPE_PROTO)
ATTR_SETTER_WITH_LIST_VALUE(int64_t, ints, AttributeProto::INTS)
ATTR_SETTER_WITH_LIST_VALUE(float, floats, AttributeProto::FLOATS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(std::string, strings, AttributeProto::STRINGS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(TensorProto, tensors, AttributeProto::TENSORS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(GraphProto, graphs, AttributeProto::GRAPHS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(TypeProto, type_protos, AttributeProto::TYPE_PROTOS)

// Lets Verify() accept attributes that this schema does not declare.
OpSchema& OpSchema::AllowUncheckedAttributes() {
  this->allows_unchecked_attributes_ = true;
  return *this;
}

// Sets the formal parameter for input slot `n`, growing the input list as needed. Slots skipped
// over are filled with default-constructed parameters.
// A negative `n` is rejected with fail_schema. Previously it wrapped to a huge size_t, skipped the
// resize and indexed out of bounds.
OpSchema& OpSchema::Input(int n, FormalParameter formal_parameter) {
  if (n < 0) {
    fail_schema("Input index must be non-negative, got ", n);
  }
  const auto index = static_cast<size_t>(n);
  if (inputs_.size() <= index) {
    inputs_.resize(index + 1);
  }
  inputs_[index] = std::move(formal_parameter);
  return *this;
}

// Declares input slot `n` from its individual properties. When __ONNX_NO_DOC_STRINGS is defined,
// the description is dropped and an empty string is stored instead.
OpSchema& OpSchema::Input(
    int n,
    std::string name,
    const std::string& description,
    std::string type_str,
    OpSchema::FormalParameterOption param_option,
    bool is_homogeneous,
    int min_arity,
    DifferentiationCategory differentiation_category) {
#ifndef __ONNX_NO_DOC_STRINGS
  const std::string& doc = description;
#else
  const std::string doc;
#endif
  FormalParameter parameter(
      std::move(name), doc, std::move(type_str), param_option, is_homogeneous, min_arity, differentiation_category);
  return Input(n, std::move(parameter));
}

// C-string convenience overload of Input(int, std::string, const std::string&, std::string, ...).
// The description is dropped when __ONNX_NO_DOC_STRINGS is defined.
OpSchema& OpSchema::Input(
    int n,
    const char* name,
    const char* description,
    const char* type_str,
    FormalParameterOption param_option,
    bool is_homogeneous,
    int min_arity,
    DifferentiationCategory differentiation_category) {
#ifndef __ONNX_NO_DOC_STRINGS
  std::string doc(description);
#else
  std::string doc;
#endif
  std::string owned_name(name);
  std::string owned_type(type_str);
  return Input(
      n,
      std::move(owned_name),
      doc,
      std::move(owned_type),
      param_option,
      is_homogeneous,
      min_arity,
      differentiation_category);
}

// Sets the formal parameter for output slot `n`, growing the output list as needed. Slots skipped
// over are filled with default-constructed parameters.
// A negative `n` is rejected with fail_schema. Previously it wrapped to a huge size_t, skipped the
// resize and indexed out of bounds.
OpSchema& OpSchema::Output(int n, FormalParameter formal_parameter) {
  if (n < 0) {
    fail_schema("Output index must be non-negative, got ", n);
  }
  const auto index = static_cast<size_t>(n);
  if (outputs_.size() <= index) {
    outputs_.resize(index + 1);
  }
  outputs_[index] = std::move(formal_parameter);
  return *this;
}

// Declares output slot `n` from its individual properties. When __ONNX_NO_DOC_STRINGS is defined,
// the description is dropped and an empty string is stored instead.
OpSchema& OpSchema::Output(
    int n,
    std::string name,
    const std::string& description,
    std::string type_str,
    OpSchema::FormalParameterOption param_option,
    bool is_homogeneous,
    int min_arity,
    DifferentiationCategory differentiation_category) {
#ifndef __ONNX_NO_DOC_STRINGS
  const std::string& doc = description;
#else
  const std::string doc;
#endif
  FormalParameter parameter(
      std::move(name), doc, std::move(type_str), param_option, is_homogeneous, min_arity, differentiation_category);
  return Output(n, std::move(parameter));
}

// C-string convenience overload of Output(int, std::string, const std::string&, std::string, ...).
// The description is dropped when __ONNX_NO_DOC_STRINGS is defined.
OpSchema& OpSchema::Output(
    int n,
    const char* name,
    const char* description,
    const char* type_str,
    FormalParameterOption param_option,
    bool is_homogeneous,
    int min_arity,
    DifferentiationCategory differentiation_category) {
#ifndef __ONNX_NO_DOC_STRINGS
  std::string doc(description);
#else
  std::string doc;
#endif
  std::string owned_name(name);
  std::string owned_type(type_str);
  return Output(
      n,
      std::move(owned_name),
      doc,
      std::move(owned_type),
      param_option,
      is_homogeneous,
      min_arity,
      differentiation_category);
}

// Declares a type constraint named `type_str` (e.g. "T") that admits the listed type strings.
// The resolved DataTypeSet is stored in type_constraints_, and the raw parameters are recorded in
// type_constraint_params_. A duplicate constraint name is rejected with fail_schema; the message
// now names the offending constraint.
OpSchema&
OpSchema::TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description) {
  if (type_constraints_.end() != type_constraints_.find(type_str)) {
    fail_schema("Duplicate type constraint name: ", type_str);
  }

  DataTypeSet allowed_types;
  for (const auto& t : constraints) {
    allowed_types.insert(Utils::DataTypeUtils::ToType(t));
  }
  // The set is not used afterwards, so move it instead of copying it into the map.
  type_constraints_.emplace(type_str, std::make_pair(std::move(allowed_types), description));
  type_constraint_params_.emplace_back(std::move(type_str), std::move(constraints), std::move(description));
  return *this;
}

// C-string / initializer-list convenience overload of
// TypeConstraint(std::string, std::vector<std::string>, std::string).
OpSchema& OpSchema::TypeConstraint(
    const char* type_str,
    std::initializer_list<const char*> constraints,
    const char* description) {
  std::vector<std::string> constraints_vector;
  constraints_vector.reserve(constraints.size());
  for (auto constraint : constraints) {
    constraints_vector.emplace_back(constraint);
  }

  // The local vector is not used again, so move it into the by-value parameter instead of copying.
  return TypeConstraint(std::string(type_str), std::move(constraints_vector), std::string(description));
}

// Resolves each formal parameter's type string into its set of allowed data types. If the string
// names a declared type constraint, that constraint's set is used. Otherwise the string is treated
// as a single concrete type.
void OpSchema::ParseAndSetTypes(
    /*out*/ std::vector<OpSchema::FormalParameter>* formal_parameters) {
  for (auto& parameter : *formal_parameters) {
    const std::string& type_name = parameter.GetTypeStr();
    DataTypeSet& resolved = parameter.MutableTypes();
    const auto constraint = type_constraints_.find(type_name);
    if (constraint == type_constraints_.end()) {
      resolved.clear();
      resolved.emplace(Utils::DataTypeUtils::ToType(type_name));
    } else {
      resolved = constraint->second.first;
    }
  }
}

// Registers a context-dependent function-body builder for `opset_version`. Passing
// kUninitializedSinceVersion means "this op's own since_version". If since_version_ is not known
// yet, the builder stays under kUninitializedSinceVersion until SinceVersion() re-keys it.
OpSchema& OpSchema::SetContextDependentFunctionBodyBuilder(
    ContextDependentFunctionBodyBuilder functionBuilder,
    int opset_version) {
  const bool use_since_version =
      opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion;
  const int key = use_since_version ? since_version_ : opset_version;
  opset_version_to_function_builder_[key] = std::move(functionBuilder);
  return *this;
}

// Builds this op's function body for `ctx` into `function_proto`. It uses the builder registered
// for the greatest opset version that is <= the requested one; kUninitializedSinceVersion means
// since_version_.
// Returns false if the builder reports failure. On success, the opset_import for this schema's
// domain is set to the requested version, and the referenced ops are validated.
// Throws std::out_of_range (via ONNX_THROW_EX) when no suitable builder exists.
bool OpSchema::BuildContextDependentFunction(
    const FunctionBodyBuildContext& ctx,
    FunctionProto& function_proto,
    int requested_opset_version) const {
  const int target_version =
      (requested_opset_version == OpSchema::kUninitializedSinceVersion) ? since_version_ : requested_opset_version;

  // First builder keyed strictly above target_version; the candidate is the entry before it.
  auto above = opset_version_to_function_builder_.upper_bound(target_version);
  if (above == opset_version_to_function_builder_.begin()) {
    // Also covers an empty map, where begin() == end() == upper_bound(...).
    ONNX_THROW_EX(
        std::out_of_range(
            std::string("Cannot find a function builder that satisfies the requested opset version: op_type = ") +
            this->name_ + ", opset_version = " + std::to_string(target_version) + "."));
  }

  --above;
  const auto& [builder_version, body_builder] = *above;
  if (!body_builder(ctx, *this, function_proto)) {
    return false;
  }
  // The builder may have added a default opset import; align its version with the requested one.
  UpdateFunctionProtoOpsetImportVersion(function_proto, target_version);
  ValidateReferencedOpsInFunction(&function_proto, target_version, builder_version);
  return true;
}

// Makes the opset_import entries of `function_proto` for this schema's own domain (domain_) carry
// `requested_opset_version`. If no such entry exists, one is appended.
// Functions stored in opset_version_to_function_body_, or produced by a builder in
// opset_version_to_function_builder_, come with predefined opset_imports. This keeps them
// consistent with the version actually requested.
// Only domain_ is updated; imports of any other domain are left untouched.
// TODO: extend this to the other domains a function body may reference.
void OpSchema::UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int requested_opset_version) const {
  bool has_domain_import = false;
  for (auto& opset_id : *function_proto.mutable_opset_import()) {
    if (opset_id.domain() != domain_) {
      continue;
    }
    has_domain_import = true;
    if (opset_id.version() != requested_opset_version) {
      opset_id.set_version(requested_opset_version);
    }
  }

  if (has_domain_import) {
    return;
  }
  auto* appended = function_proto.add_opset_import();
  appended->set_domain(domain_);
  appended->set_version(requested_opset_version);
}

// Registers a function body, given in ONNX textual syntax, for opset_version.
//
// func_body     - node list in the syntax accepted by OnnxParser; the whole string must parse.
// opset_version - version the body is registered under; kUninitializedSinceVersion falls
//                 back to since_version_ when that has been set.
// Raises std::logic_error (via ONNX_THROW_EX) on a parse error or trailing unparsed input.
// Note: emplace() does not replace an entry already registered under the same version.
OpSchema& OpSchema::FunctionBody(const char* func_body, int opset_version) {
  if (opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion) {
    opset_version = since_version_;
  }
  auto function_proto = std::make_shared<FunctionProto>();
  OnnxParser parser(func_body);
  auto status = parser.Parse(*function_proto->mutable_node());
  if (!status.IsOK())
    ONNX_THROW_EX(std::logic_error("Error parsing function body:" + status.ErrorMessage()));
  if (!parser.EndOfInput())
    ONNX_THROW_EX(std::logic_error("Extra unparsed input unexpected."));

  // Only the node list was populated above, so this attaches the
  // {domain_, opset_version} opset_import entry the body needs.
  UpdateFunctionProtoOpsetImportVersion(*function_proto, opset_version);

  // Hand the shared_ptr over to the map (as the other FunctionBody overloads do) instead of
  // copying it, which would cost an extra atomic refcount increment/decrement.
  opset_version_to_function_body_.emplace(opset_version, std::move(function_proto));
  return *this;
}

// Registers a function body built from copies of func_nodes for opset_version.
// kUninitializedSinceVersion falls back to since_version_ when that has been set. The
// body's opset_import entry for domain_ is aligned with the chosen version before it is
// stored. An entry already registered under that version is not replaced (emplace).
OpSchema& OpSchema::FunctionBody(const std::vector<NodeProto>& func_nodes, int opset_version) {
  const bool use_since_version =
      opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion;
  const int target_version = use_since_version ? since_version_ : opset_version;

  auto body = std::make_shared<FunctionProto>();
  for (const NodeProto& src : func_nodes)
    body->add_node()->CopyFrom(src);

  UpdateFunctionProtoOpsetImportVersion(*body, target_version);
  opset_version_to_function_body_.emplace(target_version, std::move(body));
  return *this;
}

// Registers a function body built from func_nodes, declaring relied_opsets as its opset
// imports, for opset_version. kUninitializedSinceVersion falls back to since_version_
// when that has been set. After the copy, the entry for domain_ is forced to the chosen
// version (or appended if relied_opsets had none). An entry already registered under that
// version is not replaced (emplace).
OpSchema& OpSchema::FunctionBody(
    const std::vector<NodeProto>& func_nodes,
    const std::vector<OperatorSetIdProto>& relied_opsets,
    int opset_version) {
  const bool use_since_version =
      opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion;
  const int target_version = use_since_version ? since_version_ : opset_version;

  auto body = std::make_shared<FunctionProto>();
  for (const OperatorSetIdProto& opset : relied_opsets)
    *body->add_opset_import() = opset;
  for (const NodeProto& src : func_nodes)
    body->add_node()->CopyFrom(src);

  UpdateFunctionProtoOpsetImportVersion(*body, target_version);
  opset_version_to_function_body_.emplace(target_version, std::move(body));
  return *this;
}

// Returns the function body that applies to requested_opset_version, or nullptr.
//
// With kUninitializedSinceVersion, the body registered under the highest version is returned
// without validation. Otherwise the body with the greatest version <= the request is chosen.
// When validate is true, that body is returned only if ValidateReferencedOpsInFunction
// accepts it. Returns nullptr when no body exists or none qualifies.
const FunctionProto* OpSchema::GetFunction(int requested_opset_version, bool validate) const {
  const auto& bodies = opset_version_to_function_body_;
  if (bodies.empty())
    return nullptr;
  if (requested_opset_version == OpSchema::kUninitializedSinceVersion)
    return bodies.rbegin()->second.get();

  auto after = bodies.upper_bound(requested_opset_version);
  if (after == bodies.begin())
    return nullptr;
  --after;

  const FunctionProto* candidate = after->second.get();
  if (validate && !ValidateReferencedOpsInFunction(candidate, requested_opset_version, after->first))
    return nullptr;
  return candidate;
}

// when requesting a function at loading time,
// requested_opset_version does not have to be the same as function_since_version.
// When they are not the same, it is necessary to verify that ops used to define the function
// are not updated between function_since_version and requested_opset_version (include requested_opset_version).
// this call only validate ops in the default domain.
// TODO: validate ops in other domains.
bool OpSchema::ValidateReferencedOpsInFunction(
    const FunctionProto* function,
    int requested_opset_version,
    int function_since_version,
    std::unordered_set<std::string>* updated_ops) const {
  bool all_ops_are_invalid = true;
  if (requested_opset_version == function_since_version) {
    return all_ops_are_invalid;
  }
  for (auto& node : function->node()) {
    if (node.domain().empty() || node.domain() == "ai.onnx") {
      const OpSchema* op1 =
          OpSchemaRegistry::Instance()->GetSchema(node.op_type(), requested_opset_version, node.domain());
      const OpSchema* op2 =
          OpSchemaRegistry::Instance()->GetSchema(node.op_type(), function_since_version, node.domain());
      if (op1 != op2) {
        if (updated_ops) {
          updated_ops->insert(node.op_type());
        }
        all_ops_are_invalid = false;
      }
    }
  }

  return all_ops_are_invalid;
}

// Lets a caller-supplied callback populate this schema in place and returns *this for chaining.
// An empty std::function is tolerated and leaves the schema untouched.
OpSchema& OpSchema::FillUsing(const std::function<void(OpSchema&)>& populator) {
  if (!populator)
    return *this;
  populator(*this);
  return *this;
}

// Fills the signature-level fields of function_body from this schema: name, doc string,
// domain, input/output names, and attribute names. The node list is not touched.
void OpSchema::BuildFunction(FunctionProto& function_body) const {
  function_body.set_name(name_);
  function_body.set_doc_string(doc_);
  function_body.set_domain(domain_);
  for (const auto& input : inputs_)
    function_body.add_input(input.GetName());
  for (const auto& output : outputs_)
    function_body.add_output(output.GetName());
  for (const auto& attr : attributes_)
    function_body.add_attribute(attr.first);

  // In a typical ONNX function, the function and all ops in its body belong to the same
  // domain, so {domain_, since_version_} is added implicitly when no opset import is present.
  // This is purely for convenience. If any body op belongs to a different domain than the
  // function itself, the function author must add all relevant opset imports explicitly.
  if (!function_body.opset_import().empty())
    return;
  auto* default_opset = function_body.add_opset_import();
  default_opset->set_domain(domain_);
  default_opset->set_version(since_version_);
}

// Numeric tensor types for math reductions, `_ir9` variant: the
// numeric_types_for_math_reduction_ir4() list followed by the four float8 formats.
const std::vector<std::string>& OpSchema::numeric_types_for_math_reduction_ir9() {
  static const std::vector<std::string> kTypes{
      "tensor(uint32)", "tensor(uint64)",
      "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return kTypes;
}

// Numeric tensor types for math reductions, `_ir4` variant: the base
// numeric_types_for_math_reduction() list with tensor(bfloat16) appended.
const std::vector<std::string>& OpSchema::numeric_types_for_math_reduction_ir4() {
  static const std::vector<std::string> kTypes{
      "tensor(uint32)", "tensor(uint64)",
      "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"};
  return kTypes;
}

// Base numeric tensor types for math reductions: 32/64-bit unsigned and signed integers
// plus float16, float and double. 8/16-bit integers are not part of this list.
const std::vector<std::string>& OpSchema::numeric_types_for_math_reduction() {
  static const std::vector<std::string> kTypes{
      "tensor(uint32)", "tensor(uint64)",
      "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)"};
  return kTypes;
}

// Numeric tensor types, `_ir12` variant: the all_numeric_types_ir11() list with
// tensor(float8e8m0) appended.
const std::vector<std::string>& OpSchema::all_numeric_types_ir12() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(bfloat16)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)",
      "tensor(uint4)",
      "tensor(int4)",
      "tensor(float4e2m1)",
      "tensor(float8e8m0)"};
  return kTypes;
}

// Numeric tensor types, `_ir13` variant: the all_numeric_types_ir12() list followed by the
// 2-bit integer types tensor(uint2) and tensor(int2).
const std::vector<std::string>& OpSchema::all_numeric_types_ir13() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(bfloat16)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)",
      "tensor(uint4)",
      "tensor(int4)",
      "tensor(float4e2m1)",
      "tensor(float8e8m0)",
      "tensor(uint2)",
      "tensor(int2)"};
  return kTypes;
}

// Numeric tensor types, `_ir11` variant: the all_numeric_types_ir10() list with
// tensor(float4e2m1) appended.
const std::vector<std::string>& OpSchema::all_numeric_types_ir11() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)",
      "tensor(uint4)", "tensor(int4)", "tensor(float4e2m1)"};
  return kTypes;
}

// Numeric tensor types, `_ir10` variant: the all_numeric_types_ir9() list followed by the
// 4-bit integer types tensor(uint4) and tensor(int4).
const std::vector<std::string>& OpSchema::all_numeric_types_ir10() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)",
      "tensor(uint4)", "tensor(int4)"};
  return kTypes;
}

// Numeric tensor types, `_ir9` variant: the all_numeric_types_ir4() list followed by the
// four float8 formats.
const std::vector<std::string>& OpSchema::all_numeric_types_ir9() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return kTypes;
}

// Numeric tensor types, `_ir4` variant: the base all_numeric_types() list with
// tensor(bfloat16) appended.
const std::vector<std::string>& OpSchema::all_numeric_types_ir4() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"};
  return kTypes;
}

// Base numeric tensor types: every unsigned/signed integer width from 8 to 64 bits plus
// float16, float and double.
const std::vector<std::string>& OpSchema::all_numeric_types() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)"};
  return kTypes;
}

// Sequence counterparts of all_numeric_types(): seq(tensor(T)) for each of the same element
// types, in the same order.
const std::vector<std::string>& OpSchema::all_numeric_sequence_types() {
  static const std::vector<std::string> kTypes{
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))"};
  return kTypes;
}

// Base tensor types: all_numeric_types() plus string, bool and the two complex types.
// This list contains no bfloat16 (see all_tensor_types_ir4()).
const std::vector<std::string>& OpSchema::all_tensor_types() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(string)", "tensor(bool)",
      "tensor(complex64)", "tensor(complex128)"};
  return kTypes;
}

// Tensor types, `_ir4` variant: all_tensor_types() with tensor(bfloat16) inserted ahead of
// the other floating-point types.
const std::vector<std::string>& OpSchema::all_tensor_types_ir4() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(string)", "tensor(bool)",
      "tensor(complex64)", "tensor(complex128)"};
  return kTypes;
}

// Integer and floating-point tensor types (including bfloat16) plus tensor(bool); no string
// and no complex types.
const std::vector<std::string>& OpSchema::all_non_complex_numeric_types_plus_bool_ir4() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(bool)"};
  return kTypes;
}

// Floating-point tensor types, `_ir4` variant: bfloat16, float16, float and double.
const std::vector<std::string>& OpSchema::all_float_types_ir4() {
  static const std::vector<std::string> kTypes{
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)"};
  return kTypes;
}

// Floating-point tensor types of all_float_types_ir4() followed by the 8-bit integer
// types tensor(int8) and tensor(uint8).
// The local static is named for this function instead of the copy-pasted
// `all_float_types_ir4`, which wrongly suggested a float-only list.
const std::vector<std::string>& OpSchema::all_float_types_plus_Xint8_ir4() {
  static const std::vector<std::string> kTypes{
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(int8)",
      "tensor(uint8)"};
  return kTypes;
}

// Floating-point tensor types, `_ir9` variant: all_float_types_ir4() followed by the four
// float8 formats.
const std::vector<std::string>& OpSchema::all_float_types_ir9() {
  static const std::vector<std::string> kTypes{
      "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return kTypes;
}

// Tensor types, `_ir9` variant: all_tensor_types_ir4() followed by the four float8 formats.
const std::vector<std::string>& OpSchema::all_tensor_types_ir9() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(string)",
      "tensor(bool)",
      "tensor(complex64)",
      "tensor(complex128)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)"};
  return kTypes;
}

// Tensor types, `_ir10` variant: all_tensor_types_ir9() followed by tensor(uint4) and
// tensor(int4).
const std::vector<std::string>& OpSchema::all_tensor_types_ir10() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(string)",
      "tensor(bool)",
      "tensor(complex64)",
      "tensor(complex128)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)",
      "tensor(uint4)",
      "tensor(int4)"};
  return kTypes;
}

// all_tensor_types_ir10() without tensor(complex64) and tensor(complex128); the remaining
// order is unchanged.
const std::vector<std::string>& OpSchema::all_non_complex_tensor_types_ir10() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(string)",
      "tensor(bool)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)",
      "tensor(uint4)",
      "tensor(int4)"};
  return kTypes;
}

// Tensor types, `_ir11` variant: all_tensor_types_ir10() with tensor(float4e2m1) appended.
const std::vector<std::string>& OpSchema::all_tensor_types_ir11() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(string)",
      "tensor(bool)",
      "tensor(complex64)",
      "tensor(complex128)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)",
      "tensor(uint4)",
      "tensor(int4)",
      "tensor(float4e2m1)"};
  return kTypes;
}

// all_tensor_types_ir11() without the two complex types; the remaining order is unchanged.
const std::vector<std::string>& OpSchema::all_non_complex_tensor_types_ir11() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(string)",
      "tensor(bool)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)",
      "tensor(uint4)",
      "tensor(int4)",
      "tensor(float4e2m1)"};
  return kTypes;
}

// Tensor types, `_ir12` variant: all_tensor_types_ir11() with tensor(float8e8m0) appended.
const std::vector<std::string>& OpSchema::all_tensor_types_ir12() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(string)",
      "tensor(bool)",
      "tensor(complex64)",
      "tensor(complex128)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)",
      "tensor(uint4)",
      "tensor(int4)",
      "tensor(float4e2m1)",
      "tensor(float8e8m0)"};
  return kTypes;
}

// all_tensor_types_ir12() without the two complex types; the remaining order is unchanged.
const std::vector<std::string>& OpSchema::all_non_complex_tensor_types_ir12() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(string)",
      "tensor(bool)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)",
      "tensor(uint4)",
      "tensor(int4)",
      "tensor(float4e2m1)",
      "tensor(float8e8m0)"};
  return kTypes;
}

// Tensor types, `_ir13` variant: all_tensor_types_ir12() followed by tensor(uint2) and
// tensor(int2).
const std::vector<std::string>& OpSchema::all_tensor_types_ir13() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(string)", "tensor(bool)", "tensor(complex64)", "tensor(complex128)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)",
      "tensor(uint4)", "tensor(int4)", "tensor(float4e2m1)", "tensor(float8e8m0)",
      "tensor(uint2)", "tensor(int2)"};
  return kTypes;
}

// all_tensor_types_ir13() without the two complex types; the remaining order is unchanged.
const std::vector<std::string>& OpSchema::all_non_complex_tensor_types_ir13() {
  static const std::vector<std::string> kTypes{
      "tensor(uint8)",
      "tensor(uint16)",
      "tensor(uint32)",
      "tensor(uint64)",
      "tensor(int8)",
      "tensor(int16)",
      "tensor(int32)",
      "tensor(int64)",
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(string)",
      "tensor(bool)",
      "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)",
      "tensor(uint4)",
      "tensor(int4)",
      "tensor(float4e2m1)",
      "tensor(float8e8m0)",
      "tensor(uint2)",
      "tensor(int2)"};
  return kTypes;
}

// Sequence counterparts of all_tensor_types(): seq(tensor(T)) for each of the same element
// types, in the same order.
const std::vector<std::string>& OpSchema::all_tensor_sequence_types() {
  static const std::vector<std::string> kTypes{
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))",
      "seq(tensor(string))", "seq(tensor(bool))",
      "seq(tensor(complex64))", "seq(tensor(complex128))"};
  return kTypes;
}

// Sequence counterparts of all_tensor_types_ir4(): seq(tensor(T)) for each of the same
// element types, in the same order.
const std::vector<std::string>& OpSchema::all_tensor_sequence_types_ir4() {
  static const std::vector<std::string> kTypes{
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(bfloat16))", "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))",
      "seq(tensor(string))", "seq(tensor(bool))",
      "seq(tensor(complex64))", "seq(tensor(complex128))"};
  return kTypes;
}

// Sequence counterparts of all_tensor_types_ir9(): seq(tensor(T)) for each of the same
// element types, in the same order.
const std::vector<std::string>& OpSchema::all_tensor_sequence_types_ir9() {
  static const std::vector<std::string> kTypes{
      "seq(tensor(uint8))",
      "seq(tensor(uint16))",
      "seq(tensor(uint32))",
      "seq(tensor(uint64))",
      "seq(tensor(int8))",
      "seq(tensor(int16))",
      "seq(tensor(int32))",
      "seq(tensor(int64))",
      "seq(tensor(bfloat16))",
      "seq(tensor(float16))",
      "seq(tensor(float))",
      "seq(tensor(double))",
      "seq(tensor(string))",
      "seq(tensor(bool))",
      "seq(tensor(complex64))",
      "seq(tensor(complex128))",
      "seq(tensor(float8e4m3fn))",
      "seq(tensor(float8e4m3fnuz))",
      "seq(tensor(float8e5m2))",
      "seq(tensor(float8e5m2fnuz))"};
  return kTypes;
}

// Sequence counterparts of all_tensor_types_ir10(): seq(tensor(T)) for each of the same
// element types, in the same order.
const std::vector<std::string>& OpSchema::all_tensor_sequence_types_ir10() {
  static const std::vector<std::string> kTypes{
      "seq(tensor(uint8))",
      "seq(tensor(uint16))",
      "seq(tensor(uint32))",
      "seq(tensor(uint64))",
      "seq(tensor(int8))",
      "seq(tensor(int16))",
      "seq(tensor(int32))",
      "seq(tensor(int64))",
      "seq(tensor(bfloat16))",
      "seq(tensor(float16))",
      "seq(tensor(float))",
      "seq(tensor(double))",
      "seq(tensor(string))",
      "seq(tensor(bool))",
      "seq(tensor(complex64))",
      "seq(tensor(complex128))",
      "seq(tensor(float8e4m3fn))",
      "seq(tensor(float8e4m3fnuz))",
      "seq(tensor(float8e5m2))",
      "seq(tensor(float8e5m2fnuz))",
      "seq(tensor(uint4))",
      "seq(tensor(int4))"};
  return kTypes;
}

// Sequence counterparts of all_tensor_types_ir11(): seq(tensor(T)) for each of the same
// element types, in the same order.
const std::vector<std::string>& OpSchema::all_tensor_sequence_types_ir11() {
  static const std::vector<std::string> kTypes{
      "seq(tensor(uint8))",
      "seq(tensor(uint16))",
      "seq(tensor(uint32))",
      "seq(tensor(uint64))",
      "seq(tensor(int8))",
      "seq(tensor(int16))",
      "seq(tensor(int32))",
      "seq(tensor(int64))",
      "seq(tensor(bfloat16))",
      "seq(tensor(float16))",
      "seq(tensor(float))",
      "seq(tensor(double))",
      "seq(tensor(string))",
      "seq(tensor(bool))",
      "seq(tensor(complex64))",
      "seq(tensor(complex128))",
      "seq(tensor(float8e4m3fn))",
      "seq(tensor(float8e4m3fnuz))",
      "seq(tensor(float8e5m2))",
      "seq(tensor(float8e5m2fnuz))",
      "seq(tensor(uint4))",
      "seq(tensor(int4))",
      "seq(tensor(float4e2m1))"};
  return kTypes;
}

// Sequence counterparts of all_tensor_types_ir12(): seq(tensor(T)) for each of the same
// element types, in the same order.
const std::vector<std::string>& OpSchema::all_tensor_sequence_types_ir12() {
  static const std::vector<std::string> kTypes{
      "seq(tensor(uint8))",
      "seq(tensor(uint16))",
      "seq(tensor(uint32))",
      "seq(tensor(uint64))",
      "seq(tensor(int8))",
      "seq(tensor(int16))",
      "seq(tensor(int32))",
      "seq(tensor(int64))",
      "seq(tensor(bfloat16))",
      "seq(tensor(float16))",
      "seq(tensor(float))",
      "seq(tensor(double))",
      "seq(tensor(string))",
      "seq(tensor(bool))",
      "seq(tensor(complex64))",
      "seq(tensor(complex128))",
      "seq(tensor(float8e4m3fn))",
      "seq(tensor(float8e4m3fnuz))",
      "seq(tensor(float8e5m2))",
      "seq(tensor(float8e5m2fnuz))",
      "seq(tensor(uint4))",
      "seq(tensor(int4))",
      "seq(tensor(float4e2m1))",
      "seq(tensor(float8e8m0))"};
  return kTypes;
}

// Sequence counterparts of all_tensor_types_ir13(): seq(tensor(T)) for each of the same
// element types, in the same order.
const std::vector<std::string>& OpSchema::all_tensor_sequence_types_ir13() {
  static const std::vector<std::string> kTypes{
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(bfloat16))", "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))",
      "seq(tensor(string))", "seq(tensor(bool))", "seq(tensor(complex64))", "seq(tensor(complex128))",
      "seq(tensor(float8e4m3fn))", "seq(tensor(float8e4m3fnuz))",
      "seq(tensor(float8e5m2))", "seq(tensor(float8e5m2fnuz))",
      "seq(tensor(uint4))", "seq(tensor(int4))", "seq(tensor(float4e2m1))", "seq(tensor(float8e8m0))",
      "seq(tensor(uint2))", "seq(tensor(int2))"};
  return kTypes;
}

// Optional types over the all_tensor_types() element set: first optional(seq(tensor(T)))
// for every T, then optional(tensor(T)) for every T, both in all_tensor_types() order.
const std::vector<std::string>& OpSchema::all_optional_types() {
  static const std::vector<std::string> kTypes{
      // optional sequences
      "optional(seq(tensor(uint8)))",
      "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))",
      "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))",
      "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))",
      "optional(seq(tensor(int64)))",
      "optional(seq(tensor(float16)))",
      "optional(seq(tensor(float)))",
      "optional(seq(tensor(double)))",
      "optional(seq(tensor(string)))",
      "optional(seq(tensor(bool)))",
      "optional(seq(tensor(complex64)))",
      "optional(seq(tensor(complex128)))",
      // optional tensors
      "optional(tensor(uint8))",
      "optional(tensor(uint16))",
      "optional(tensor(uint32))",
      "optional(tensor(uint64))",
      "optional(tensor(int8))",
      "optional(tensor(int16))",
      "optional(tensor(int32))",
      "optional(tensor(int64))",
      "optional(tensor(float16))",
      "optional(tensor(float))",
      "optional(tensor(double))",
      "optional(tensor(string))",
      "optional(tensor(bool))",
      "optional(tensor(complex64))",
      "optional(tensor(complex128))"};
  return kTypes;
}

// Optional types over the all_tensor_types_ir4() element set: first optional(seq(tensor(T)))
// for every T, then optional(tensor(T)) for every T, both in all_tensor_types_ir4() order.
// The local static uses a neutral name. The copy-pasted name `all_optional_types` matched a
// different accessor whose list lacks bfloat16.
const std::vector<std::string>& OpSchema::all_optional_types_ir4() {
  static const std::vector<std::string> kTypes{
      // optional sequences
      "optional(seq(tensor(uint8)))",
      "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))",
      "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))",
      "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))",
      "optional(seq(tensor(int64)))",
      "optional(seq(tensor(bfloat16)))",
      "optional(seq(tensor(float16)))",
      "optional(seq(tensor(float)))",
      "optional(seq(tensor(double)))",
      "optional(seq(tensor(string)))",
      "optional(seq(tensor(bool)))",
      "optional(seq(tensor(complex64)))",
      "optional(seq(tensor(complex128)))",
      // optional tensors
      "optional(tensor(uint8))",
      "optional(tensor(uint16))",
      "optional(tensor(uint32))",
      "optional(tensor(uint64))",
      "optional(tensor(int8))",
      "optional(tensor(int16))",
      "optional(tensor(int32))",
      "optional(tensor(int64))",
      "optional(tensor(bfloat16))",
      "optional(tensor(float16))",
      "optional(tensor(float))",
      "optional(tensor(double))",
      "optional(tensor(string))",
      "optional(tensor(bool))",
      "optional(tensor(complex64))",
      "optional(tensor(complex128))"};
  return kTypes;
}

// Optional types, `_ir9` variant: the all_optional_types_ir4() list followed by
// optional(tensor(T)) for the four float8 formats. No optional(seq(...)) float8 entries are
// present. NOTE(review): confirm that omission is intended.
// The local static uses a neutral name instead of the copy-pasted `all_optional_types`.
const std::vector<std::string>& OpSchema::all_optional_types_ir9() {
  static const std::vector<std::string> kTypes{
      // optional sequences
      "optional(seq(tensor(uint8)))",
      "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))",
      "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))",
      "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))",
      "optional(seq(tensor(int64)))",
      "optional(seq(tensor(bfloat16)))",
      "optional(seq(tensor(float16)))",
      "optional(seq(tensor(float)))",
      "optional(seq(tensor(double)))",
      "optional(seq(tensor(string)))",
      "optional(seq(tensor(bool)))",
      "optional(seq(tensor(complex64)))",
      "optional(seq(tensor(complex128)))",
      // optional tensors
      "optional(tensor(uint8))",
      "optional(tensor(uint16))",
      "optional(tensor(uint32))",
      "optional(tensor(uint64))",
      "optional(tensor(int8))",
      "optional(tensor(int16))",
      "optional(tensor(int32))",
      "optional(tensor(int64))",
      "optional(tensor(bfloat16))",
      "optional(tensor(float16))",
      "optional(tensor(float))",
      "optional(tensor(double))",
      "optional(tensor(string))",
      "optional(tensor(bool))",
      "optional(tensor(complex64))",
      "optional(tensor(complex128))",
      "optional(tensor(float8e4m3fn))",
      "optional(tensor(float8e4m3fnuz))",
      "optional(tensor(float8e5m2))",
      "optional(tensor(float8e5m2fnuz))"};
  return kTypes;
}

const std::vector<std::string>& OpSchema::all_optional_types_ir10() {
  // IR version 10 accepts every IR-9 optional type (same order), followed by the
  // 4-bit integer optional tensors. Built once on first call.
  static const std::vector<std::string> all_optional_types = [] {
    std::vector<std::string> types(OpSchema::all_optional_types_ir9());
    types.emplace_back("optional(tensor(uint4))");
    types.emplace_back("optional(tensor(int4))");
    return types;
  }();
  return all_optional_types;
}

const std::vector<std::string>& OpSchema::all_optional_types_ir11() {
  // IR version 11 accepts every IR-10 optional type (same order), followed by the
  // float4e2m1 optional tensor. Built once on first call.
  static const std::vector<std::string> all_optional_types = [] {
    std::vector<std::string> types(OpSchema::all_optional_types_ir10());
    types.emplace_back("optional(tensor(float4e2m1))");
    return types;
  }();
  return all_optional_types;
}

const std::vector<std::string>& OpSchema::all_optional_types_ir12() {
  // IR version 12 accepts every IR-11 optional type (same order), followed by the
  // float8e8m0 optional tensor. Built once on first call.
  static const std::vector<std::string> all_optional_types = [] {
    std::vector<std::string> types(OpSchema::all_optional_types_ir11());
    types.emplace_back("optional(tensor(float8e8m0))");
    return types;
  }();
  return all_optional_types;
}

const std::vector<std::string>& OpSchema::all_optional_types_ir13() {
  // IR version 13 accepts every IR-12 optional type (same order), followed by the
  // 2-bit integer optional tensors. Built once on first call.
  static const std::vector<std::string> all_optional_types = [] {
    std::vector<std::string> types(OpSchema::all_optional_types_ir12());
    types.emplace_back("optional(tensor(uint2))");
    types.emplace_back("optional(tensor(int2))");
    return types;
  }();
  return all_optional_types;
}

// Validates the schema's formal parameters and derives the cached arity bounds.
// - Computes min/max input and output counts from the Single/Optional/Variadic
//   options of the declared parameters.
// - Enforces that only the last input/output may be variadic and that every
//   input/output has a non-empty name (throws std::logic_error via ONNX_THROW_EX).
// - Parses parameter type strings and builds any registered function bodies.
void OpSchema::Finalize() {
// Local validation helper; #undef'd at the end of this function so it does not
// leak into the rest of the translation unit.
#define ENFORCE(x)                                                                                      \
  do {                                                                                                  \
    if (!(x))                                                                                           \
      ONNX_THROW_EX(std::logic_error("ONNX Schema " + name_ + ": failed validating the check: " + #x)); \
  } while (0)

  // Calculate min/max number of inputs.
  // <Min number of inputs> = <number of "single" inputs> + <number of
  // "optional" but not trailing inputs> (+ the min arity of a trailing variadic
  // input). <Max number of inputs> = <number of all inputs> or
  // std::numeric_limits<int>::max() (if the last input is variadic).

  max_input_ = 0;
  min_input_ = 0;
  min_output_ = 0;
  max_output_ = 0;

  // min_input_ is raised to the running count at every Single input, so Optional
  // inputs only count toward the minimum when a Single input follows them;
  // trailing Optional inputs raise only the maximum.
  for (size_t i = 0; i < inputs_.size(); ++i) {
    switch (inputs_[i].GetOption()) {
      case OpSchema::Single:
        ++max_input_;
        min_input_ = max_input_;
        break;
      case OpSchema::Optional:
        ++max_input_;
        break;
      case OpSchema::Variadic:
        // Only last input formal parameter could be variadic.
        ENFORCE((inputs_.size() - 1) == i);
        min_input_ = max_input_ + inputs_[i].GetMinArity();
        max_input_ = std::numeric_limits<int>::max();
        break;
    }
  }

  // Calculate min/max number of outputs (same rules as for inputs).
  for (size_t i = 0; i < outputs_.size(); ++i) {
    switch (outputs_[i].GetOption()) {
      case OpSchema::Single:
        ++max_output_;
        min_output_ = max_output_;
        break;
      case OpSchema::Optional:
        ++max_output_;
        break;
      case OpSchema::Variadic:
        // Only last output formal parameter could be variadic.
        ENFORCE((outputs_.size() - 1) == i);
        min_output_ = max_output_ + outputs_[i].GetMinArity();
        max_output_ = std::numeric_limits<int>::max();
        break;
    }
  }

  // all inputs and outputs have names
  for (const auto& it : inputs_) {
    ENFORCE(!(it.GetName().empty()));
  }
  for (const auto& it : outputs_) {
    ENFORCE(!(it.GetName().empty()));
  }

  ParseAndSetTypes(&inputs_);
  ParseAndSetTypes(&outputs_);

  for (auto& func : opset_version_to_function_body_) {
    BuildFunction(*func.second);
  }

#undef ENFORCE
}

// Reports whether nodes of this operator are deterministic.
// An explicitly set value is returned as-is. Otherwise it is inferred:
// - any GRAPH/GRAPHS attribute makes the result NonDeterministic;
// - a context-dependent function body makes it Unknown;
// - for a static function body, the first body node whose schema is missing
//   (Unknown) or whose own determinism is NonDeterministic/Unknown decides the
//   result;
// - otherwise Deterministic.
OpSchema::NodeDeterminism OpSchema::GetNodeDeterminism() const {
  if (node_determinism_ != NodeDeterminism::Unknown) {
    return node_determinism_;
  }

  for (const auto& name_and_attr : attributes()) {
    const auto attr_type = name_and_attr.second.type;
    if (attr_type == AttributeProto::GRAPH || attr_type == AttributeProto::GRAPHS) {
      return NodeDeterminism::NonDeterministic;
    }
  }

  if (HasContextDependentFunction()) {
    return NodeDeterminism::Unknown;
  }

  const FunctionProto* body = GetFunction();
  if (body == nullptr) {
    return NodeDeterminism::Deterministic;
  }

  const OpSchemaRegistry& registry = *OpSchemaRegistry::Instance();
  std::unordered_map<std::string, int> opset_by_domain;
  for (const auto& opset : body->opset_import()) {
    opset_by_domain[opset.domain()] = opset.version();
  }

  for (const NodeProto& node : body->node()) {
    // Domains not imported by the body resolve to opset version 0.
    const auto found = opset_by_domain.find(node.domain());
    const int version = (found == opset_by_domain.end()) ? 0 : found->second;
    const OpSchema* node_schema = registry.GetSchema(node.op_type(), version, node.domain());
    if (node_schema == nullptr) {
      return NodeDeterminism::Unknown;
    }
    const NodeDeterminism nested = node_schema->GetNodeDeterminism();
    if (nested == NodeDeterminism::NonDeterministic || nested == NodeDeterminism::Unknown) {
      return nested;
    }
  }
  return NodeDeterminism::Deterministic;
}

// Records an explicit determinism value for this schema and returns *this so
// builder calls can be chained.
OpSchema& OpSchema::SetNodeDeterminism(NodeDeterminism node_determinism) {
  node_determinism_ = node_determinism;
  return *this;
}

// Writes a human-readable summary of `schema` to `out`: attributes, inputs,
// outputs, the documentation string, and (when known) the definition site.
std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
  // Prints one section header followed by "  <index>, <name> : <doc> : <type>"
  // per formal parameter, substituting placeholders for empty fields.
  const auto print_params = [&out](const char* header, const auto& params) {
    out << header << '\n';
    if (params.empty()) {
      out << "  (no explicit description available)" << '\n';
      return;
    }
    size_t index = 0;
    for (const auto& param : params) {
      const auto& name = param.GetName();
      const auto& description = param.GetDescription();
      const auto& type_str = param.GetTypeStr();
      out << "  " << index << ", " << (name.empty() ? "(unnamed)" : name) << " : "
          << (description.empty() ? "(no doc)" : description) << " : "
          << (type_str.empty() ? "(no type)" : type_str) << '\n';
      ++index;
    }
  };

  if (!schema.attributes_.empty()) {
    out << "Attributes:" << '\n';
    for (const auto& entry : schema.attributes_) {
      out << "  " << entry.second.name << " : " << entry.second.description << '\n';
    }
  }
  if (schema.max_input_ > 0) {
    print_params("Inputs:", schema.inputs_);
  }
  if (schema.max_output_ > 0) {
    print_params("Outputs:", schema.outputs_);
  }
  out << '\n';
  if (!schema.doc()) {
    out << "(no documentation yet)" << '\n';
  } else {
    out << schema.doc();
  }
  out << '\n';
  if (schema.line_) {
    out << "Defined at " << schema.file_ << ":" << schema.line_ << '\n';
  }
  return out;
}

// Returns the single DomainToVersionRange object, constructed on first call
// (function-local static; its initialization is thread-safe since C++11).
OpSchemaRegistry::DomainToVersionRange& OpSchemaRegistry::DomainToVersionRange::Instance() {
  static DomainToVersionRange domain_to_version_range;
  return domain_to_version_range;
}

// Private accessor used by OpSchemaRegisterOnce and OpSchemaRegistry::map().
// Returns the raw schema map (a function-local static) without triggering the
// one-time registration of the built-in operator sets.
OpName_Domain_Version_Schema_Map& OpSchemaRegistry::GetMapWithoutEnsuringRegistration() {
  static OpName_Domain_Version_Schema_Map schema_map;
  return schema_map;
}

// Returns the global schema map. Unless __ONNX_DISABLE_STATIC_REGISTRATION is
// defined, the first call also registers the built-in ONNX (and ONNX-ML when
// enabled), training and preview operator sets.
OpName_Domain_Version_Schema_Map& OpSchemaRegistry::map() {
  auto& map = GetMapWithoutEnsuringRegistration();

  // The following class is used to register operators the
  // first time this method is called, in a thread-safe fashion
  // (via initialization of the function-local static below).
  class SchemasRegisterer {
   public:
    SchemasRegisterer() {
      // In debug builds, the number of schema registered in this constructor
      // is compared against the number of calls to schema registration macros.
#ifndef NDEBUG
      size_t dbg_initial_schema_count = GetRegisteredSchemaCount();
#endif

      RegisterOnnxOperatorSetSchema();

#ifdef ONNX_ML
      RegisterOnnxMLOperatorSetSchema();
#endif

      // Invoke register of training operators.
      RegisterOnnxTrainingOperatorSetSchema();

      // Invoke register of experimental operators.
      RegisterOnnxPreviewOperatorSetSchema();

#ifndef NDEBUG
      size_t dbg_registered_schema_count = GetRegisteredSchemaCount() - dbg_initial_schema_count;
      // Check enabled only if schemas for all opset versions are loaded
      if (OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 0) {
        // Both counts are passed as size_t and formatted with %zu: the message is
        // printf-style, and %u with a 64-bit size_t argument is undefined behavior.
        const size_t dbg_expected_schema_count = static_cast<size_t>(ONNX_DBG_GET_COUNT_IN_OPSETS());
        ONNX_ASSERTM(
            dbg_registered_schema_count == dbg_expected_schema_count,
            "%zu schema were exposed from operator sets and automatically placed into the static registry.  "
            "%zu were expected based on calls to registration macros. Operator set functions may need to be updated.",
            dbg_registered_schema_count,
            dbg_expected_schema_count);
      }
#endif
    }

   private:
#ifndef NDEBUG
    // Total number of (op, domain, version) schema entries currently in the map.
    static size_t GetRegisteredSchemaCount() {
      size_t count = 0;
      for (const auto& x : GetMapWithoutEnsuringRegistration()) {
        for (const auto& y : x.second) {
          count += y.second.size();
        }
      }
      return count;
    }
#endif
  };

#ifndef __ONNX_DISABLE_STATIC_REGISTRATION
  static SchemasRegisterer schemasRegisterer;
#endif

  return map;
}

// Replaces every non-overlapping occurrence of `from` in `s` with `to`, scanning
// left to right and never re-matching inside text that was just inserted.
// Returns the number of replacements performed.
// An empty `from` is treated as "nothing to replace": it returns 0 and leaves `s`
// untouched. Previously an empty pattern matched at every position and the loop
// never terminated.
size_t ReplaceAll(std::string& s, const char* from, const char* to) {
  const std::string::size_type lenFrom = std::strlen(from);
  if (lenFrom == 0) {
    return 0;
  }
  const std::string::size_type lenTo = std::strlen(to);
  size_t numReplaced = 0;
  // Resume searching just past the inserted replacement so `to` is never rescanned.
  for (std::string::size_type pos = s.find(from, 0, lenFrom); pos != std::string::npos;
       pos = s.find(from, pos + lenTo, lenFrom)) {
    s.replace(pos, lenFrom, to, lenTo);
    ++numReplaced;
  }
  return numReplaced;
}

// Reports whether this build was compiled with __ONNX_DISABLE_STATIC_REGISTRATION,
// i.e. whether OpSchemaRegistry::map() skips registering the built-in operator sets.
bool IsOnnxStaticRegistrationDisabled() {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
  constexpr bool kStaticRegistrationDisabled = true;
#else
  constexpr bool kStaticRegistrationDisabled = false;
#endif
  return kStaticRegistrationDisabled;
}

} // namespace ONNX_NAMESPACE
