// Copyright (c) ONNX Project Contributors

/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "onnx/version_converter/convert.h"

#include <iostream>
#include <memory>
#include <string>

#include "onnx/common/ir_pb_converter.h"

namespace ONNX_NAMESPACE {
namespace version_conversion {

// Convenience entry point: converts `mp_in` to `target_version` using a
// DefaultVersionConverter.
//
// The starting version is taken from the first opset_import entry of `mp_in`
// whose domain is the default one (empty string or "ai.onnx"). If no such
// entry exists, the starting version stays at 0.
ModelProto ConvertVersion(const ModelProto& mp_in, int target_version) {
  OpSetID source_opset(0);
  for (const auto& opset_entry : mp_in.opset_import()) {
    const bool is_default_domain = opset_entry.domain().empty() || opset_entry.domain() == "ai.onnx";
    if (!is_default_domain) {
      continue;
    }
    // Only the first default-domain entry is considered.
    source_opset.setVersion(opset_entry.version());
    break;
  }

  OpSetID target_opset(target_version);
  DefaultVersionConverter converter;
  return converter.convert_version(mp_in, source_opset, target_opset);
}

// Converts graph `g` in place from `initial_version` to `target_version`.
//
// The conversion walks one opset version at a time (up or down). At each step
// every node is visited and, when searchOpDomainMap reports a match for the
// node's op at the current version/step, the adapter returned by
// adapter_lookup is applied to it. Graph-valued attributes of those nodes are
// converted recursively for the same single step.
//
// Nodes that are left untouched (including their subgraph attributes):
//   - "ConstantFill" (experimental op; a warning is printed when DEBUG is set)
//   - nodes whose domain is neither empty nor "ai.onnx" (warning when DEBUG)
//   - "Undefined" and "Captured" nodes
//
// After each step the version of the default-domain entry in g's opset list
// is moved by `step`. If the graph has no default-domain entry, the list is
// left unchanged.
//
// `g` must be non-null and both versions must pass assertInVersionRange.
void DefaultVersionConverter::convert_graph(
    const std::shared_ptr<Graph>& g,
    const OpSetID& initial_version,
    const OpSetID& target_version) const {
  assertNonNull(g);

  // TODO: Move to Inter-Domain Converter
  // Get initial model versions
  // std::vector<OpSetID> initial_versions = g->opset_versions_mutable();

  // No conversion necessary if Model has single, equivalent opset version
  // if (initial_versions.size() == 1 && initial_versions[0].version ==
  //    target_version.version && initial_versions[0].domain ==
  //    target_version.domain) {
  //  return mp_in;
  // }

  // Check if versions are valid
  assertInVersionRange(initial_version.version());
  assertInVersionRange(target_version.version());

  // Walk one version at a time towards target_version. When both versions are
  // equal the loop below never runs, so the sign of `step` is irrelevant.
  int64_t curr_version = initial_version.version();
  const int64_t step = (target_version.version() > initial_version.version()) ? 1 : -1;

  // Identify the index of the default domain in g.opset_versions. The default
  // domain may be spelled either "" or "ai.onnx"; both spellings are accepted
  // here, matching the checks applied to model imports and node domains
  // elsewhere in this file. If several entries match, the last one is used.
  bool has_default_domain = false;
  unsigned int domain_index = 0;
  for (unsigned int i = 0; i < g->opset_versions_mutable().size(); i++) {
    const auto& domain = g->opset_versions_mutable()[i].domain();
    if (domain.empty() || domain == "ai.onnx") {
      domain_index = i;
      has_default_domain = true;
    }
  }
  if (!has_default_domain) {
    // Do not fall back to index 0: that entry (if it exists at all) belongs to
    // a different domain and must not have its version changed.
    debug("Graph has no default-domain opset entry; its opset version list will not be updated");
  }

  while (curr_version != target_version.version()) {
    debug(
        "curr_version: " + ONNX_NAMESPACE::to_string(curr_version) +
        ", next_version: " + ONNX_NAMESPACE::to_string(curr_version + step));
    Node* cur_op = nullptr;
    graph_node_list_iterator it = g->begin();
    // Iterate through and call adapter returned by adapter_lookup for ops from
    // current_version opset. We have to manipulate the iterator explicitly because cur_op
    // might change when applying the adapter (e.g. for deprecated ops)
    while (it != g->end()) {
      cur_op = *it;
      const std::string op_name = cur_op->kind().toString();
      debug(std::string("Finding schema for ") + op_name);
      if (op_name == "ConstantFill") {
        if (DEBUG) {
          std::cerr
              << "Warning: skipping schema search for experimental op 'ConstantFill' and keeping the op as is. "
                 "Please be advised the converted model may not be working properly if target runtime does not support this "
                 "experimental op."
              << '\n';
        }
      } else if (!cur_op->domain().empty() && cur_op->domain() != "ai.onnx") {
        if (DEBUG) {
          std::cerr << "Warning: opset domain '" << cur_op->domain() << "' is not supported." << '\n';
        }
      } else if (op_name != "Undefined" && op_name != "Captured") {
        // NOTE: .at() throws std::out_of_range when op_name has no entry in
        // all_schemas.
        auto& op_domain_map = all_schemas.at(op_name);
        OpSetID curr_id(curr_version);
        OpSetID next_id(curr_version + step);
        if (searchOpDomainMap(op_domain_map, curr_version, step)) {
          // Op is specifically defined for this domain and version.
          // A missing adapter is presumably reported by adapter_lookup itself
          // (it returns a reference, so there is no null result to check here).
          auto& op_adapter = adapter_lookup(cur_op, curr_id, next_id);
          if (DEBUG) {
            std::cerr << "Applying adapter" << '\n';
          }
          // adapt should handle replacing node in graph; continue iterating
          // from the node it hands back.
          cur_op = op_adapter.adapt(g, cur_op);
          it = graph_node_list_iterator(cur_op, kNextDirection);
        }
        // Recursively convert any subgraph attributes by the same single step
        for (const auto& attr : cur_op->attributeNames()) {
          if (cur_op->kindOf(attr) == AttributeKind::g) {
            convert_graph(cur_op->g(attr), curr_id, next_id);
          }
        }
      }
      it++;
    }
    // Update model version
    curr_version += step;
    if (has_default_domain) {
      g->opset_versions_mutable()[domain_index].incrementVersion(step);
    }
  }
}

// Converts the model `mp_in` from `initial_version` to `target_version` and
// returns the converted model; `mp_in` itself is not modified.
//
// Both versions are checked with assertDefaultDomain, and `initial_version`
// must agree with every opset_import entry of `mp_in` that carries the same
// domain. The model is imported into a Graph, converted in place by
// convert_graph, and exported into a fresh ModelProto built by PrepareOutput.
ModelProto DefaultVersionConverter::convert_version(
    const ModelProto& mp_in,
    const OpSetID& initial_version,
    const OpSetID& target_version) const {
  assertDefaultDomain(initial_version.domain(), target_version.domain());

  // The caller-supplied starting version has to match what the model declares.
  for (const auto& declared_opset : mp_in.opset_import()) {
    if (declared_opset.domain() != initial_version.domain()) {
      continue;
    }
    ONNX_ASSERTM(
        initial_version.version() == declared_opset.version(),
        "initial_version does not reflect current state of model")
  }

  // Work on an in-memory graph built from the input model.
  std::shared_ptr<Graph> graph(ImportModelProto(mp_in));
  convert_graph(graph, initial_version, target_version);

  // Export the converted graph as a ModelProto
  debug("Finished conversion; returning model");
  ModelProto converted_model = PrepareOutput(mp_in);
  ExportModelProto(&converted_model, graph);
  return converted_model;
}

} // namespace version_conversion
} // namespace ONNX_NAMESPACE
