/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "onnx/defs/controlflow/utils.h"

#include <string>
#include <vector>

namespace ONNX_NAMESPACE {

void ClearShape(TypeProto& input_type) {
  if (input_type.has_tensor_type()) {
    input_type.mutable_tensor_type()->clear_shape();
  } else if (input_type.has_sequence_type()) {
    auto& seq_type = *input_type.mutable_sequence_type();
    if (seq_type.has_elem_type()) {
      ClearShape(*(seq_type.mutable_elem_type()));
    }
  } else if (input_type.has_optional_type()) {
    auto& opt_type = *input_type.mutable_optional_type();
    if (opt_type.has_elem_type()) {
      ClearShape(*(opt_type.mutable_elem_type()));
    }
  }
}

void IfInferenceFunction(InferenceContext& ctx) {
  // there are no inputs so we just need to run the subgraph inferencing for
  // then/else subgraphs and apply those to the outputs.
  std::vector<const TypeProto*> subgraph_input_types; // none
  std::vector<const TensorProto*> input_data; // none

  std::vector<const TypeProto*> then_output_types;
  std::vector<const TypeProto*> else_output_types;

  // Run inferencing on the subgraph
  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("then_branch");
  if (graphInferencer) {
    then_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  graphInferencer = ctx.getGraphAttributeInferencer("else_branch");
  if (graphInferencer) {
    else_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  auto num_outputs = ctx.getNumOutputs();
  auto num_then_outputs = then_output_types.size();
  auto num_else_outputs = else_output_types.size();

  // the output types for then and else should be the same
  if (num_then_outputs != num_else_outputs) {
    fail_type_inference(
        "then_branch and else_branch produce different number of outputs. ",
        num_then_outputs,
        " != ",
        num_else_outputs);
  }

  if (num_then_outputs != num_outputs) {
    fail_type_inference("If node has ", num_outputs, " but subgraphs produce ", num_then_outputs);
  }

  for (size_t i = 0, end = then_output_types.size(); i < end; ++i) {
    auto then_output = then_output_types[i];
    auto else_output = else_output_types[i];

    auto* if_output = ctx.getOutputType(i);
    *if_output = *then_output;

    UnionTypeInfo(*else_output, *if_output);
  }
}

void LoopInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  assert(num_inputs >= 2);
  auto num_loop_state_vars = num_inputs - 2; // skip 'M' and 'cond'

  std::vector<const TypeProto*> subgraph_input_types;
  subgraph_input_types.reserve(num_inputs);

  std::vector<TypeProto> temporary_type_protos;
  temporary_type_protos.reserve(num_inputs - 2);

  // create TypeProto to validate iteration number type is the same as the
  // optional 'M' input for max iterations.
  TypeProto iter_num_type;
  iter_num_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
  subgraph_input_types.push_back(&iter_num_type);

  // 'cond'
  subgraph_input_types.push_back(ctx.getInputType(1));

  // loop state value types get propagated to outputs, but shape may change
  // across iterations so don't propagate it to the outputs and don't pass it
  // into the subgraph inferencing
  for (size_t i = 2; i < num_inputs; ++i) {
    propagateElemTypeFromInputToOutput(ctx, i, i - 2);

    // copy so we can remove the shape before passing to the subgraph
    // inferencing
    temporary_type_protos.push_back(*ctx.getInputType(i));
    auto& input_type = temporary_type_protos.back();

    ClearShape(input_type);
    subgraph_input_types.push_back(&input_type);
  }

  // Run inferencing on the subgraph
  std::vector<const TypeProto*> subgraph_output_types;

  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    input_data.push_back(nullptr); // iteration number
    for (size_t i = 1; i < num_inputs; ++i) {
      input_data.push_back(ctx.getInputData(i));
    }

    subgraph_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  // if empty(), assume inferencing was skipped
  if (!subgraph_output_types.empty()) {
    auto num_outputs = ctx.getNumOutputs();

    // subgraph outputs the condition value first but that is only used
    // internally and not returned by Loop.
    if (subgraph_output_types.size() != num_outputs + 1) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ",
          subgraph_output_types.size(),
          " outputs. Expected ",
          num_outputs + 1);
    }

    // check loop state values match. we should already have type/shape info
    for (size_t i = 0; i < num_outputs; ++i) {
      auto* subgraph_output_type = subgraph_output_types[i + 1]; // skip 'cond'
      auto* loop_output_type = ctx.getOutputType(i);

      const bool is_loop_state_var = i < num_loop_state_vars;

      if (!subgraph_output_type->has_tensor_type() && !subgraph_output_type->has_sequence_type() &&
          !subgraph_output_type->has_optional_type()) {
        fail_type_inference(
            "Loop 'body' subgraph outputs should all be tensors or sequences or optionals, but output ",
            i,
            " was ",
            subgraph_output_type->value_case());
      }

      if (!is_loop_state_var && !subgraph_output_type->has_tensor_type()) {
        fail_type_inference(
            "Loop 'body' subgraph scan outputs should all be tensors but output ",
            i,
            " was ",
            subgraph_output_type->value_case());
      }

      // if there's an existing type check it matches. otherwise propagate
      propagateElemTypeWithValidation(subgraph_output_type, loop_output_type);

      if (is_loop_state_var) {
        // shape may change across iterations so ignore.
      } else {
        // propagate shape
        if (subgraph_output_type->tensor_type().has_shape()) {
          // per iteration output. first dimension will be number of iterations
          // but we don't know that value yet
          TypeProto inferred_type(*subgraph_output_type);
          auto* mutable_inferred_tensor_type = inferred_type.mutable_tensor_type();
          auto* mutable_inferred_shape = mutable_inferred_tensor_type->mutable_shape();

          mutable_inferred_shape->clear_dim();

          // add empty dimension for number of iterations
          mutable_inferred_shape->add_dim();

          // add dimensions from subgraph output shape
          for (const auto& dim : subgraph_output_type->tensor_type().shape().dim()) {
            (*mutable_inferred_shape->add_dim()) = dim;
          }

          mergeInShapeInfo(*mutable_inferred_tensor_type, *loop_output_type->mutable_tensor_type());
        }
      }
    }
  }
}

int handle_negative_axis_validate(const std::string& attrib, int axis, int rank) {
  if (-rank > axis || axis >= rank) {
    fail_shape_inference(attrib, " axis value ", axis, " is invalid for a tensor of rank ", rank);
  }
  return (axis >= 0 ? axis : axis + rank);
}

void ScanInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  auto num_scan_inputs = narrow_cast<size_t>(ctx.getAttribute("num_scan_inputs")->i());
  auto num_loop_state_vars = num_inputs - num_scan_inputs;
  auto num_outputs = ctx.getNumOutputs();
  auto num_scan_outputs = num_outputs - num_loop_state_vars;

  std::vector<int64_t> axes, output_axes;
  if (getRepeatedAttribute(ctx, "scan_input_axes", axes)) {
    if (axes.size() != num_scan_inputs) {
      fail_shape_inference(
          "Number of scan input axes specified (",
          axes.size(),
          ") is not equal to number of scan inputs (",
          num_scan_inputs,
          ").");
    }
  } else {
    axes.insert(axes.end(), num_scan_inputs, 0);
  }

  if (getRepeatedAttribute(ctx, "scan_output_axes", output_axes)) {
    if (output_axes.size() != num_scan_outputs) {
      fail_shape_inference(
          "Number of scan output axes specified (",
          output_axes.size(),
          ") is not equal to number of scan outputs (",
          num_scan_outputs,
          ").");
    }
  } else {
    output_axes.insert(output_axes.end(), num_scan_outputs, 0);
  }

  std::vector<TypeProto> temporary_type_protos;
  temporary_type_protos.reserve(num_inputs);

  std::vector<const TypeProto*> subgraph_input_types;
  subgraph_input_types.reserve(num_inputs);

  TensorShapeProto_Dimension sequence_len_dim;

  for (size_t i = 0; i < num_inputs; ++i) {
    bool is_loop_state_var = i < num_loop_state_vars;
    bool has_shape = hasInputShape(ctx, i);
    const auto* input_type = ctx.getInputType(i);

    // Enforce type constraint for inputs
    if (!input_type || !input_type->has_tensor_type()) {
      fail_type_inference("Scan input ", i, " was not a tensor.");
    }

    if (is_loop_state_var) {
      // If it's a loop state variable we can propagate type and shape 1:1 to
      // the matching Scan output.
      // We can also pass through the type and shape to the subgraph but need to
      // remove the batch size dimension from the shape.
      propagateElemTypeFromInputToOutput(ctx, i, i);
      if (has_shape)
        propagateShapeFromInputToOutput(ctx, i, i);

      subgraph_input_types.push_back(input_type);
    } else {
      // For other inputs there is no fixed relationships to the Scan outputs,
      // so we don't propagate type/shape information.
      // We can pass through the type and shape to the subgraph inputs but
      // need to remove the sequence length dimensions from the shape.
      if (has_shape) {
        const auto& shape = input_type->tensor_type().shape();

        // remove sequence length dimensions and add to subgraph_input_types
        int axis = static_cast<int>(axes[i - num_loop_state_vars]);
        axis = handle_negative_axis_validate("scan_input_axes", axis, shape.dim_size());

        // update sequence_len if a value is available

        const auto& dims = shape.dim();
        mergeInDimensionInfo(dims.Get(axis), sequence_len_dim, 1);

        temporary_type_protos.emplace_back(RemoveIthDimensionFromShape(*input_type, axis));
        subgraph_input_types.emplace_back(&temporary_type_protos.back());

      } else {
        subgraph_input_types.push_back(input_type);
      }
    }
  }

  // Run inferencing on the subgraph
  std::vector<const TypeProto*> output_types;

  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    input_data.reserve(num_inputs);
    for (size_t i = 0; i < num_inputs; ++i) {
      // ctx.getInputData(i), the input to scan, does not represent the input to
      // scan body. So, we pass in null, to represent an unknown value.
      input_data.push_back(nullptr);
    }

    output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  // if empty(), assume inferencing was skipped
  if (!output_types.empty()) {
    if (output_types.size() != num_outputs) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ",
          output_types.size(),
          " outputs. Expected ",
          num_outputs);
    }

    // propagate type/shape information for loop state variables and outputs
    for (size_t i = 0; i < num_outputs; ++i) {
      const bool is_loop_state_var = i < num_loop_state_vars;
      auto* subgraph_output_type = output_types[i];
      auto* scan_output_type = ctx.getOutputType(i);
      auto* mutable_scan_output_tensor_type = scan_output_type->mutable_tensor_type();

      if (!subgraph_output_type->has_tensor_type()) {
        fail_type_inference("Scan 'body' subgraph outputs should all be tensors but output ", i, " was not");
      }
      auto& subgraph_output_tensor_type = subgraph_output_type->tensor_type();

      if (is_loop_state_var) {
        // merge shape; type already propagated
        mergeInShapeInfo(subgraph_output_tensor_type, *mutable_scan_output_tensor_type);
      } else {
        scan_output_type->mutable_tensor_type()->set_elem_type(subgraph_output_tensor_type.elem_type());

        // propagate shape
        if (subgraph_output_tensor_type.has_shape()) {
          // infer shape of scan-output from the shape of scan-output-element
          // by adding sequence-length at the correct axis position
          const TensorShapeProto& subgraph_output_shape = subgraph_output_tensor_type.shape();
          TensorShapeProto inferred_shape;

          auto subgraph_output_rank = subgraph_output_shape.dim_size();
          auto output_rank = subgraph_output_rank + 1;
          int output_axis = static_cast<int>(output_axes[i - num_loop_state_vars]);
          output_axis = handle_negative_axis_validate("scan_output_axes", output_axis, output_rank);

          for (int j = 0; j < output_axis; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
          *(inferred_shape.add_dim()) = sequence_len_dim;
          for (int j = output_axis; j < subgraph_output_rank; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);

          // Merge inferred shape with existing shape information
          mergeInShapeInfo(inferred_shape, *mutable_scan_output_tensor_type);
        }
      }
    }
  }
}

} // namespace ONNX_NAMESPACE
