/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "onnx/defs/reduction/utils.h"

#include <algorithm>
#include <string>
#include <vector>

namespace ONNX_NAMESPACE {

// Builds the list of type strings allowed for the "T" constraint of a reduction op.
// The list always starts with OpSchema::numeric_types_for_math_reduction_ir4().
// It is then optionally extended with the two 8-bit integer types (uint8, then int8)
// and finally with Boolean, in that order.
static std::vector<std::string> GetSupportedDataTypesForReductionOps(bool supports8bit, bool supports_bool) {
  auto allowed = OpSchema::numeric_types_for_math_reduction_ir4();

  if (supports8bit)
    allowed.insert(allowed.end(), {"tensor(uint8)", "tensor(int8)"});

  if (supports_bool)
    allowed.push_back("tensor(bool)");

  return allowed;
}

std::function<void(OpSchema&)> ReduceOpGenerator(
    const char* name,
    const char* empty_value,
    bool supports_8bit_datatypes,
    bool axes_input,
    const char* func_body,
    const ContextDependentFunctionBodyBuilder& function_builder,
    bool supports_boolean_datatype /* = false */) {
  return [=](OpSchema& schema) {
    std::string doc = R"DOC(
Computes the {name} of the input tensor's elements along the provided axes. The resulting
tensor has the same rank as the input if `keepdims` equals 1. If `keepdims` equals 0, then
the resulting tensor has the reduced dimension pruned. Input tensors of rank zero are
valid. Reduction over an empty set of values yields {empty_value}.
)DOC";
    if (supports_boolean_datatype) {
      doc += R"DOC(

If the input data type is Boolean, the comparison should consider `False < True`.)DOC";
    }
    doc += R"DOC(

The above behavior is similar to numpy, with the exception that numpy defaults `keepdims`
to `False` instead of `True`.)DOC";

    ReplaceAll(doc, "{name}", name);
    ReplaceAll(doc, "{empty_value}", empty_value);
    POPULATE_OP_DOC_STR(doc = doc;);
    schema.SetDoc(doc.c_str());
    schema.Attr(
        "keepdims",
        "Keep the reduced dimension or not, default 1 means keep reduced dimension.",
        AttributeProto::INT,
        static_cast<int64_t>(1));
    schema.Input(0, "data", "An input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable);
    if (axes_input) {
      schema.Attr(
          "noop_with_empty_axes",
          "Defines behavior when axes is not provided or is empty. "
          "If false (default), reduction happens over all axes (similar to the case "
          "when `axis=None` in numpy). "
          "If true, reduction happens over an empty set of axes (similar to the case "
          "when `axis=()` in numpy). "
          "Note that reduction over an empty set of axes means that the reduction step "
          "behaves like a no-op (identity function), but composite-reduction operators "
          "will still perform the non-reduction steps as needed. "
          "Thus, ReduceLogSum returns the Log of input tensor, and ReduceSumSquare "
          "returns the Square of the input tensor, in this case.",
          AttributeProto::INT,
          static_cast<int64_t>(0));
      schema.Input(
          1,
          "axes",
          "Optional input list of integers, along which to reduce. "
          "The default is to reduce over empty axes. "
          "When axes is empty (either not provided or explicitly empty), behavior depends on 'noop_with_empty_axes': "
          "reduction over all axes if 'noop_with_empty_axes' is false, "
          "and reduction over the empty set of axes when 'noop_with_empty_axes' is true. "
          "Accepted range is [-r, r-1] where r = rank(data).",
          "tensor(int64)",
          OpSchema::Optional,
          true,
          1,
          OpSchema::NonDifferentiable);
    } else {
      schema.Attr(
          "axes",
          "A list of integers, along which to reduce. The default is to reduce over "
          "all the dimensions of the input tensor. Accepted range is [-r, r-1] where r = rank(data).",
          AttributeProto::INTS,
          OPTIONAL_VALUE);
    }
    schema.Output(0, "reduced", "Reduced output tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable);
    schema.TypeConstraint(
        "T",
        GetSupportedDataTypesForReductionOps(supports_8bit_datatypes, supports_boolean_datatype),
        supports_boolean_datatype ? "Constrain input and output types to numeric and Boolean tensors."
                                  : "Constrain input and output types to numeric tensors.");
    if (func_body) {
      schema.FunctionBody(func_body);
    } else if (function_builder) {
      schema.SetContextDependentFunctionBodyBuilder(function_builder);
    }
    schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      if (!hasNInputShapes(ctx, 1)) {
        return;
      }

      int64_t keep_dims = 1, noop_with_empty_axes = 0;
      auto attr_proto = ctx.getAttribute("keepdims");
      if (attr_proto) {
        keep_dims = attr_proto->i();
      }
      auto noop_attr_proto = ctx.getAttribute("noop_with_empty_axes");
      if (noop_attr_proto) {
        noop_with_empty_axes = noop_attr_proto->i();
      }
      std::vector<int64_t> axes;
      if (ctx.hasInput(1)) { // axes is input
        if (ctx.getAttribute("axes")) {
          fail_shape_inference("axes as an input and attribute cannot be specified at the same time.");
        }

        const TensorProto* axesInitializer = ctx.getInputData(1);
        if (axesInitializer == nullptr) {
          // skip if axes is not an initializer
          return;
        }
        std::vector<int64_t> axes_values = ParseData<int64_t>(axesInitializer);
        axes.assign(axes_values.begin(), axes_values.end());
      } else { // axes is attribute
        auto axes_proto = ctx.getAttribute("axes");
        if (axes_proto)
          axes.assign(axes_proto->ints().begin(), axes_proto->ints().end());
      }
      auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
      if (noop_with_empty_axes && axes.empty()) {
        propagateShapeFromInputToOutput(ctx, 0, 0);
        return;
      }
      int64_t input_ndim = input_shape.dim_size();
      auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();

      for (int64_t& axe : axes) {
        if (axe < -input_ndim || axe >= input_ndim) {
          fail_shape_inference("axis must be in [-rank, rank-1]. Input rank was ", input_ndim);
        }
        if (axe < 0)
          axe += input_ndim;
      }
      for (int i = 0; i < input_ndim; ++i) {
        // axes empty means reduce all dim
        if (!axes.empty() && std::find(axes.begin(), axes.end(), i) == axes.end()) {
          auto dim = output_shape->add_dim();
          dim->CopyFrom(input_shape.dim(i));
        } else {
          if (keep_dims == 1) {
            auto dim = output_shape->add_dim();
            dim->set_dim_value(1);
          }
        }
      }
    });
  };
}
} // namespace ONNX_NAMESPACE
