// Copyright (c) ONNX Project Contributors

/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include <stdexcept>

#include "gtest/gtest.h"
#include "onnx/checker.h"
#include "onnx/common/constants.h"
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"

using namespace ONNX_NAMESPACE::checker;

namespace ONNX_NAMESPACE {
namespace Test {

// Utilities. TODO: Turn them into reusable ONNX utilities.

// Builds a TensorProto whose data type is `elem_type` and which holds the
// single value `value` in the matching typed data field.
//
// Only FLOAT and DOUBLE are handled. Any other element type trips the assert;
// when asserts are compiled out the tensor is returned with only its data type
// set and no data.
static TensorProto ToTensor(double value, TensorProto_DataType elem_type) {
  TensorProto tensor;
  tensor.set_data_type(elem_type);

  if (elem_type == TensorProto_DataType::TensorProto_DataType_FLOAT) {
    tensor.add_float_data(static_cast<float>(value));
  } else if (elem_type == TensorProto_DataType::TensorProto_DataType_DOUBLE) {
    tensor.add_double_data(value);
  } else {
    // FLOAT16 is intentionally left out for now; it would look like:
    //   tensor.add_int32_data(onnxruntime::math::floatToHalf((float)value));
    assert(false);
  }

  return tensor;
}

// Appends one node to `functionProto` for each entry of `node_defs`, copying
// the entry's op type, inputs, outputs and attributes in their original order.
static void BuildNodes(FunctionProto& functionProto, const std::vector<FunctionBodyHelper::NodeDef>& node_defs) {
  for (const auto& def : node_defs) {
    auto& new_node = *functionProto.add_node();
    new_node.set_op_type(def.op_type);

    for (const auto& input_name : def.inputs) {
      new_node.add_input(input_name);
    }

    for (const auto& output_name : def.outputs) {
      new_node.add_output(output_name);
    }

    for (const auto& attribute : def.attributes) {
      *new_node.add_attribute() = attribute.proto;
    }
  }
}

// Fills `functionProto` from `node_defs` and then hands it to
// `schema.BuildFunction` to complete it.
//
// The nodes are added before BuildFunction is called; keep that order.
// Always returns true, so it can be used directly as the result of a
// context-dependent function body builder.
static bool BuildFunctionProto(
    FunctionProto& functionProto,
    const OpSchema& schema,
    const std::vector<FunctionBodyHelper::NodeDef>& node_defs) {
  BuildNodes(functionProto, node_defs);
  schema.BuildFunction(functionProto);
  return true;
}

// A monomorphic context-dependent function test-case.
static bool
BuildFloatFunctionBody(const FunctionBodyBuildContext& /*ctx*/, const OpSchema& schema, FunctionProto& functionProto) {
  // Create a scalar-tensor constant 2.0 of float type:
  auto two_as_tensor = ToTensor(2.0, TensorProto_DataType::TensorProto_DataType_FLOAT);

  std::vector<FunctionBodyHelper::NodeDef> body{
      // nodes: {outputs, op, inputs, attributes}
      {{"Two"}, "Constant", {}, {{"value", two_as_tensor}}},
      {{"Y"}, "Mul", {"X", "Two"}}};

  return BuildFunctionProto(functionProto, schema, body);
}

static void RegisterCustomFuncFloatSchema() {
  ONNX_NAMESPACE::OpSchema schema;
  schema.SetName("CustomFuncFloat")
      .SetDomain(ONNX_DOMAIN)
      .SinceVersion(12)
      .SetDoc("This operator returns an output tensor that is twice the input tensor.")
      .Input(0, "X", "Input tensor", "T", OpSchema::Single)
      .Output(0, "Y", "Output tensor", "T", OpSchema::Single)
      .TypeConstraint("T", {"tensor(float)"}, "Type of the input and output values")
      .SetNodeDeterminism(OpSchema::NodeDeterminism::Deterministic)
      .SetContextDependentFunctionBodyBuilder(BuildFloatFunctionBody);
  ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(schema);
  (void)unused;
}

// Test for a context-dependent function built without any type context.
//
// Registers CustomFuncFloat, builds its function body for a single node and
// passes the generated FunctionProto to the checker.
TEST(FunctionAPITest, ContextDependentFunctionTest) {
  RegisterCustomFuncFloatSchema();

  const auto* schema = OpSchemaRegistry::Schema("CustomFuncFloat", 12, ONNX_DOMAIN);
  // Fatal assertion: `schema` is dereferenced by every check that follows, so
  // continuing with a null pointer would crash instead of reporting a failure.
  ASSERT_TRUE(schema != nullptr);
  EXPECT_FALSE(schema->HasFunction());
  EXPECT_TRUE(schema->HasContextDependentFunction());

  NodeProto nodeProto;
  nodeProto.set_op_type("CustomFuncFloat");
  nodeProto.add_input("X");
  nodeProto.add_output("Y");

  FunctionBodyBuildContextImpl ctx(nodeProto);
  FunctionProto fnProto;
  // Fatal assertion: the remaining checks only make sense on a built body.
  ASSERT_TRUE(schema->BuildContextDependentFunction(ctx, fnProto));
  // Expected body: the "Two" constant plus the Mul node.
  EXPECT_EQ(fnProto.node_size(), 2);

  LexicalScopeContext lexicalScope;
  CheckerContext checkerCtx;
  std::unordered_map<std::string, int> opset_imports({{ONNX_DOMAIN, 12}});
  checkerCtx.set_opset_imports(opset_imports);
  checkerCtx.set_ir_version(7);
  check_function(fnProto, checkerCtx, lexicalScope);
}

// A polymorphic context-dependent function test-case.

static bool
BuildFunctionBody(const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
  // Create a scalar-tensor constant 2.0 of input-type:
  auto* tp = ctx.getInputType(0);
  if ((tp == nullptr) || (!tp->has_tensor_type()))
    return false;
  auto elem_type = (TensorProto_DataType)tp->tensor_type().elem_type();
  auto two_as_tensor = ToTensor(2.0, elem_type);

  std::vector<FunctionBodyHelper::NodeDef> body{
      // nodes: {outputs, op, inputs, attributes}
      {{"Two"}, "Constant", {}, {{"value", two_as_tensor}}},
      {{"Y"}, "Mul", {"X", "Two"}}};

  return BuildFunctionProto(functionProto, schema, body);
}

static void RegisterCustomFunctionSchema() {
  ONNX_NAMESPACE::OpSchema schema;
  schema.SetName("CustomFunction")
      .SetDomain(ONNX_DOMAIN)
      .SinceVersion(12)
      .SetDoc("This operator returns an output tensor that is twice the input tensor.")
      .Input(0, "X", "Input tensor", "T", OpSchema::Single)
      .Output(0, "Y", "Output tensor", "T", OpSchema::Single)
      .TypeConstraint("T", {"tensor(float)", "tensor(double)"}, "Type of the input and output values")
      .SetNodeDeterminism(OpSchema::NodeDeterminism::Deterministic)
      .SetContextDependentFunctionBodyBuilder(BuildFunctionBody);
  ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(schema);
  (void)unused;
}

// Registers two versions of the function op MySub (opset 2 and opset 9), each
// carrying function bodies tagged with an opset version, and checks which
// model opset imports yield a usable function body.
TEST(FunctionAPITest, VersionedFunctionBodyTest) {
  // This test illustrates issues of ONNX function ops.
  // It is oversimplified in that only one primary op (Sub) is used in the function body.
  // ONNX opset     1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18
  // MySub:            2                    9                               // MySub function op is created at opset 2.
  //                                                                        // Its semantic is updated at opset 9
  // Body Ideal:       2           6  7     9          13 14    16          // Ideally function body shall be provided
  //                                                                        // each time there is any version bump of
  //                                                                        // used primary ops. It will be more
  //                                                                        // frequent
  //                                                                        // if more primary ops are used.
  // Body Real:        2                    9                   16          // In real life, we seldom add function body
  //                                                                        // due to primary op update
  // Sub:           1              6  7                13 14                // Version bumps of Sub
  // Model:            y  y  y  y  n  n  n  y  y  y  y n  n  n  y  y  y     // Model can(y)/cannot(n) be used
  //                                                                        // with this opset import version.
  ONNX_NAMESPACE::OpSchema schema_ver2;
  schema_ver2.SetName("MySub")
      .SetDomain(ONNX_DOMAIN)
      .SinceVersion(2)
      .SetDoc("Z = Sub (X, Y)")
      .Input(0, "X", "Input tensor X", "T", OpSchema::Single)
      .Input(1, "Y", "Input tensor Y", "T", OpSchema::Single)
      .Output(0, "Z", "Output tensor Z", "T", OpSchema::Single)
      .TypeConstraint("T", {"tensor(float)", "tensor(double)"}, "Type of the input and output values")
      .FunctionBody(
          R"ONNX(
        {
          Z = Sub (X, Y)
        }
        )ONNX",
          2);

  // The opset-9 schema carries two bodies, tagged 9 and 16 ("Body Real" above).
  ONNX_NAMESPACE::OpSchema schema_ver9;
  schema_ver9.SetName("MySub")
      .SetDomain(ONNX_DOMAIN)
      .SinceVersion(9)
      .SetDoc("Z = Sub (X, Y)")
      .Input(0, "X", "Input tensor X", "T", OpSchema::Single)
      .Input(1, "Y", "Input tensor Y", "T", OpSchema::Single)
      .Output(0, "Z", "Output tensor Z", "T", OpSchema::Single)
      .TypeConstraint("T", {"tensor(float)", "tensor(double)"}, "Type of the input and output values")
      .FunctionBody(
          R"ONNX(
        {
          Z = Sub (X, Y)
        }
        )ONNX",
          9)
      .FunctionBody(
          R"ONNX(
        {
          Z = Sub (X, Y)
        }
        )ONNX",
          16);

  ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused2(schema_ver2);
  (void)unused2;
  ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused9(schema_ver9);
  (void)unused9;

  const auto* schema2 = OpSchemaRegistry::Schema("MySub", 2, ONNX_DOMAIN);
  // Fatal assertion: `schema2` is dereferenced inside the loop below.
  ASSERT_TRUE(schema2 != nullptr);
  for (int model_opset_import = 2; model_opset_import < 9; model_opset_import++) {
    ONNX_TRY {
      bool validate = true;
      const FunctionProto* function = schema2->GetFunction(model_opset_import, validate);
      if (model_opset_import >= 6) { // function body should be updated at opset 6 where Sub is updated
        ASSERT_TRUE(function == nullptr);
      } else {
        ASSERT_TRUE(function);
      }
    }
    ONNX_CATCH(const std::runtime_error&) {
      // An exception is tolerated only for the opsets with no valid body (6-8).
      ONNX_HANDLE_EXCEPTION(
          [&]() { ASSERT_TRUE(model_opset_import == 6 || model_opset_import == 7 || model_opset_import == 8); });
    }
  }

  const auto* schema9 = OpSchemaRegistry::Schema("MySub", 9, ONNX_DOMAIN);
  // Fatal assertion: `schema9` is dereferenced inside the loop below.
  ASSERT_TRUE(schema9 != nullptr);
  for (int model_opset_import = 9; model_opset_import < 10; model_opset_import++) {
    ONNX_TRY {
      const FunctionProto* function = schema9->GetFunction(model_opset_import);
      ASSERT_TRUE(function);
    }
    ONNX_CATCH(const std::runtime_error&) {
      // The loop only visits opset 9, so this condition never holds here: any
      // exception caught at this point is reported as a test failure.
      ONNX_HANDLE_EXCEPTION(
          [&]() { ASSERT_TRUE(model_opset_import == 13 || model_opset_import == 14 || model_opset_import == 15); });
    }
  }
}

// Test for a context-dependent function that reads the input type from the
// build context.
//
// Registers CustomFunction, builds its function body for a node whose single
// input is typed as a float tensor, and passes the generated FunctionProto to
// the checker.
TEST(FunctionAPITest, TypeContextTest) {
  RegisterCustomFunctionSchema();

  const auto* schema = OpSchemaRegistry::Schema("CustomFunction", 12, ONNX_DOMAIN);
  // Fatal assertion: `schema` is dereferenced by every check that follows, so
  // continuing with a null pointer would crash instead of reporting a failure.
  ASSERT_TRUE(schema != nullptr);
  EXPECT_FALSE(schema->HasFunction());
  EXPECT_TRUE(schema->HasContextDependentFunction());

  NodeProto nodeProto;
  nodeProto.set_op_type("CustomFunction");
  nodeProto.add_input("X");
  nodeProto.add_output("Y");

  // Type supplied to the build context for input 0.
  TypeProto floatTypeProto;
  floatTypeProto.mutable_tensor_type()->set_elem_type(TensorProto_DataType::TensorProto_DataType_FLOAT);

  FunctionBodyBuildContextImpl ctx(nodeProto, {floatTypeProto});
  FunctionProto fnProto;
  // Fatal assertion: the remaining checks only make sense on a built body.
  ASSERT_TRUE(schema->BuildContextDependentFunction(ctx, fnProto));
  // Expected body: the "Two" constant plus the Mul node.
  EXPECT_EQ(fnProto.node_size(), 2);

  LexicalScopeContext lexicalScope;
  CheckerContext checkerCtx;
  std::unordered_map<std::string, int> opset_imports({{ONNX_DOMAIN, 12}});
  checkerCtx.set_opset_imports(opset_imports);
  checkerCtx.set_ir_version(7);
  check_function(fnProto, checkerCtx, lexicalScope);
}

} // namespace Test
} // namespace ONNX_NAMESPACE
