// Copyright (c) ONNX Project Contributors

/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "gtest/gtest.h"
#include "onnx/defs/operator_sets.h"
#include "onnx/defs/schema.h"

using namespace ONNX_NAMESPACE;

namespace ONNX_NAMESPACE {
namespace Test {

// Checks that IsOnnxStaticRegistrationDisabled() agrees with the
// __ONNX_DISABLE_STATIC_REGISTRATION compile-time switch.
TEST(SchemaRegistrationTest, DisabledOnnxStaticRegistrationAPICall) {
  // Query once; which outcome is expected depends on how the binary was built.
  const auto static_registration_disabled = IsOnnxStaticRegistrationDisabled();
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
  EXPECT_TRUE(static_registration_disabled);
#else
  EXPECT_FALSE(static_registration_disabled);
#endif
}

// Schemas of all versions are registered by default.
// Further schema manipulation (deregister, then register again) is expected to be error-free.
// Runs only when static registration is enabled.
TEST(SchemaRegistrationTest, RegisterAllByDefaultAndManipulateSchema) {
#ifndef __ONNX_DISABLE_STATIC_REGISTRATION

  // Expects all opsets registered by default (loaded version 0).
  // EXPECT_EQ is used instead of EXPECT_TRUE(a == b) so that a failure prints the actual value.
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 0);

  // Should find schema for all versions
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 1));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 6));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 13));

  // Clear all opset schema registration
  DeregisterOnnxOperatorSetSchema();

  // Should not find any opset
  EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Add"));

  // Register all opset versions
  RegisterOnnxOperatorSetSchema();

  // Should find all opset
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add"));
#endif
}

// By default ONNX registers all opset versions and selective schema loading cannot be tested,
// so the tests below run only when static registration is disabled.
//
// This test deregisters everything, registers all opset versions (version 0), checks that
// known operators resolve to the expected schema versions, and finally deregisters again.
TEST(SchemaRegistrationTest, RegisterAndDeregisterAllOpsetSchemaVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION

  // Clear all opset schema registration
  DeregisterOnnxOperatorSetSchema();
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), -1);

  // Should not find schema for any op
  EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Acos"));
  EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Add"));
  EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Trilu"));

  // Register all opset versions
  RegisterOnnxOperatorSetSchema(0);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 0);

  // Should find schema for all ops. Available versions are:
  // Acos-7
  // Add-1,6,7,13,14
  // Trilu-14
  //
  // ASSERT_NE (fatal) guards every lookup that is dereferenced on the next line:
  // a missing schema must fail the test instead of crashing on a null pointer.
  auto schema = OpSchemaRegistry::Schema("Acos", 7);
  ASSERT_NE(nullptr, schema);
  EXPECT_EQ(schema->SinceVersion(), 7);

  schema = OpSchemaRegistry::Schema("Add", 14);
  ASSERT_NE(nullptr, schema);
  EXPECT_EQ(schema->SinceVersion(), 14);

  schema = OpSchemaRegistry::Schema("Trilu");
  ASSERT_NE(nullptr, schema);
  EXPECT_EQ(schema->SinceVersion(), 14);

  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 1));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 6));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 13));

  // Clear all opset schema registration
  DeregisterOnnxOperatorSetSchema();
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), -1);

  // Should not find schema for any op
  EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Acos"));
  EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Add"));
  EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Trilu"));
#endif
}

// Register only opset-13 and check which schemas become visible.
// Runs only when static registration is disabled.
TEST(SchemaRegistrationTest, RegisterSpecifiedOpsetSchemaVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
  DeregisterOnnxOperatorSetSchema();
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), -1);
  RegisterOnnxOperatorSetSchema(13);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 13);

  // ASSERT_NE (fatal) is required here: the pointer is dereferenced on the next line,
  // so a missing schema must fail the test rather than crash it.
  auto opSchema = OpSchemaRegistry::Schema("Add");
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 13);

  // Should not find any Add schema at or below opset 12
  opSchema = OpSchemaRegistry::Schema("Add", 12);
  EXPECT_EQ(nullptr, opSchema);

  // Trilu is only available from opset 14, so it should not be found
  opSchema = OpSchemaRegistry::Schema("Trilu");
  EXPECT_EQ(nullptr, opSchema);

  // Acos-7 is the latest Acos before specified 13
  opSchema = OpSchemaRegistry::Schema("Acos", 13);
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 7);
#endif
}

// Register opset-11, then opset-14
// Expects Reg(11, 14) == Reg(11) U Reg(14)
// Runs only when static registration is disabled.
TEST(SchemaRegistrationTest, RegisterMultipleOpsetSchemaVersions_UpgradeVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
  DeregisterOnnxOperatorSetSchema();
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), -1);

  // Register opset 11
  RegisterOnnxOperatorSetSchema(11);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 11);
  // Register opset 14
  // Do not fail on duplicate schema registration request
  RegisterOnnxOperatorSetSchema(14, false);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 14);

  // ASSERT_NE (fatal) guards every lookup below that is dereferenced right after:
  // a missing schema must fail the test instead of crashing on a null pointer.

  // Acos-7 is the latest before/at opset 11 and 14
  auto opSchema = OpSchemaRegistry::Schema("Acos");
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 7);

  // Add-7 is the latest before/at opset 11
  // Add-14 is the latest before/at opset 14
  // Should find both Add-7,14
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 14));

  // Should find the max version 14
  opSchema = OpSchemaRegistry::Schema("Add");
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 14);

  // Should find Add-7 as the max version <=13
  opSchema = OpSchemaRegistry::Schema("Add", 13);
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 7);

  // Should find opset 14
  opSchema = OpSchemaRegistry::Schema("Trilu");
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 14);
#endif
}

// Register opset-14, then opset-11
// Expects Reg(14, 11) == Reg(11) U Reg(14)
// Runs only when static registration is disabled.
TEST(SchemaRegistrationTest, RegisterMultipleOpsetSchemaVersions_DowngradeVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
  DeregisterOnnxOperatorSetSchema();
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), -1);

  // Register opset 14
  RegisterOnnxOperatorSetSchema(14);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 14);
  // Register opset 11
  // Do not fail on duplicate schema registration request
  RegisterOnnxOperatorSetSchema(11, false);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 11);

  // ASSERT_NE (fatal) guards every lookup below that is dereferenced right after:
  // a missing schema must fail the test instead of crashing on a null pointer.

  // Acos-7 is the latest before/at opset 11 and 14
  auto opSchema = OpSchemaRegistry::Schema("Acos");
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 7);

  // Add-7 is the latest before/at opset 11
  // Add-14 is the latest before/at opset 14
  // Should find both Add-7,14
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 14));

  // Should find the max version 14
  opSchema = OpSchemaRegistry::Schema("Add");
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 14);

  // Should find Add-7 as the max version <=13
  opSchema = OpSchemaRegistry::Schema("Add", 13);
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 7);

  // Should find opset 14
  opSchema = OpSchemaRegistry::Schema("Trilu");
  ASSERT_NE(nullptr, opSchema);
  EXPECT_EQ(opSchema->SinceVersion(), 14);
#endif
}

// Register opset-11, then all versions
// Expects no error
// Runs only when static registration is disabled.
TEST(SchemaRegistrationTest, RegisterSpecificThenAllVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
  // EXPECT_EQ is used instead of EXPECT_TRUE(a == b) throughout so that a
  // failure prints the loaded version that was actually observed.
  DeregisterOnnxOperatorSetSchema();
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), -1);

  // Register opset 11
  RegisterOnnxOperatorSetSchema(11);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 11);

  // Register all opset versions
  // Do not fail on duplicate schema registration request
  RegisterOnnxOperatorSetSchema(0, false);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 0);

  // Should find schema for all ops
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Acos"));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add"));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Trilu"));

  // Should find schema for all versions
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 1));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 6));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 13));
#endif
}

// Register all versions, then opset 11
// Expects no error
// Runs only when static registration is disabled.
TEST(SchemaRegistrationTest, RegisterAllThenSpecificVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
  // EXPECT_EQ is used instead of EXPECT_TRUE(a == b) throughout so that a
  // failure prints the loaded version that was actually observed.
  DeregisterOnnxOperatorSetSchema();
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), -1);

  // Register all opset versions
  RegisterOnnxOperatorSetSchema(0);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 0);

  // Register opset 11
  // Do not fail on duplicate schema registration request
  RegisterOnnxOperatorSetSchema(11, false);
  EXPECT_EQ(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion(), 11);

  // Should find schema for all ops
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Acos"));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add"));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Trilu"));

  // Should find schema for all versions
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 1));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 6));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
  EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 13));
#endif
}

} // namespace Test
} // namespace ONNX_NAMESPACE
