// Copyright (c) ONNX Project Contributors

/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#include "onnx/common/ir_pb_converter.h"

#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace ONNX_NAMESPACE {

// Part 1: convert ONNX Protobuf to IR
static std::unique_ptr<Graph> graphProtoToGraph(const GraphProto& gp, bool nested, const int ir_version = IR_VERSION);

static Tensor tensorProtoToTensor(const ONNX_NAMESPACE::TensorProto& tp) {
  Tensor ret;

  ret.sizes().reserve(tp.dims_size());
  for (int i = 0; i < tp.dims_size(); i++) {
    ret.sizes().push_back(tp.dims(i));
  }

  ret.elem_type() = tp.data_type();
  switch (tp.data_type()) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
    case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64: {
      ret.floats().reserve(tp.float_data_size());
      for (int i = 0; i < tp.float_data_size(); i++) {
        ret.floats().push_back(tp.float_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT2:
    case ONNX_NAMESPACE::TensorProto_DataType_INT2:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
    case ONNX_NAMESPACE::TensorProto_DataType_INT4:
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E8M0:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: {
      ret.int32s().reserve(tp.int32_data_size());
      for (int i = 0; i < tp.int32_data_size(); i++) {
        ret.int32s().push_back(tp.int32_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
      ret.int64s().reserve(tp.int64_data_size());
      for (int i = 0; i < tp.int64_data_size(); i++) {
        ret.int64s().push_back(tp.int64_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {
      ret.uint64s().reserve(tp.uint64_data_size());
      for (int i = 0; i < tp.uint64_data_size(); i++) {
        ret.uint64s().push_back(tp.uint64_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
    case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128: {
      ret.doubles().reserve(tp.double_data_size());
      for (int i = 0; i < tp.double_data_size(); i++) {
        ret.doubles().push_back(tp.double_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_STRING: {
      ret.strings().reserve(tp.string_data_size());
      for (int i = 0; i < tp.string_data_size(); i++) {
        ret.strings().push_back(tp.string_data(i));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
      [[fallthrough]];
    default:
      fail_convert("Unknown tensor data type");
  }

  // The only way to know if we should be using raw_data or
  // <type>_data is to look at which of them is size zero.
  if (tp.has_raw_data()) {
    ret.set_raw_data(tp.raw_data());
  }

  if (tp.has_name()) {
    ret.setName(tp.name());
  }
  if (tp.has_segment()) {
    ret.set_segment_begin_and_end(tp.segment().begin(), tp.segment().end());
  }

  for (int i = 0; i < tp.external_data_size(); i++) {
    ret.external_data().emplace_back(tp.external_data(i).key(), tp.external_data(i).value());
  }
  if (tp.has_data_location()) {
    ret.data_location() = tp.data_location();
  }
  return ret;
}

static void convertAttribute(const ONNX_NAMESPACE::AttributeProto& ap, Node* n, const int ir_version = IR_VERSION) {
  Symbol sym = Symbol(ap.name());
  switch (ap.type()) {
    case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT:
      n->f_(sym, ap.f());
      break;
    case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS: {
      std::vector<double> floats;
      floats.reserve(ap.floats_size());
      for (int i = 0; i < ap.floats_size(); i++) {
        floats.push_back(ap.floats(i));
      }
      n->fs_(sym, std::move(floats));
      break;
    }
    case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
      n->i_(sym, ap.i());
      break;
    case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS: {
      std::vector<int64_t> ints;
      ints.reserve(ap.ints_size());
      for (int i = 0; i < ap.ints_size(); i++) {
        ints.push_back(ap.ints(i));
      }
      n->is_(sym, std::move(ints));
      break;
    }
    case ONNX_NAMESPACE::AttributeProto_AttributeType_STRING:
      n->s_(sym, ap.s());
      break;
    case ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS: {
      std::vector<std::string> strings;
      strings.reserve(ap.strings_size());
      for (int i = 0; i < ap.strings_size(); i++) {
        strings.push_back(ap.strings(i));
      }
      n->ss_(sym, std::move(strings));
      break;
    }
    case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR:
      n->t_(sym, tensorProtoToTensor(ap.t()));
      break;
    case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS: {
      std::vector<Tensor> tensors;
      tensors.reserve(ap.tensors_size());
      for (int i = 0; i < ap.tensors_size(); i++) {
        tensors.emplace_back(tensorProtoToTensor(ap.tensors(i)));
      }
      n->ts_(sym, std::move(tensors));
      break;
    }
    case ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTO:
      n->tp_(sym, ap.tp());
      break;
    case ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTOS: {
      std::vector<TypeProto> types;
      types.reserve(ap.type_protos_size());
      for (int i = 0; i < ap.type_protos_size(); i++) {
        types.push_back(ap.type_protos(i));
      }
      n->tps_(sym, std::move(types));
      break;
    }
    case ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH:
      n->g_(sym, graphProtoToGraph(ap.g(), true, ir_version));
      break;
    case ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS: {
      std::vector<std::shared_ptr<Graph>> graphs;
      graphs.reserve(ap.graphs_size());
      for (int i = 0; i < ap.graphs_size(); i++) {
        graphs.push_back(graphProtoToGraph(ap.graphs(i), true, ir_version));
      }
      n->gs_(sym, std::move(graphs));
      break;
    }
    case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR:
    case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSORS:
      fail_convert("Sparse tensors not supported.");
      break;
    case ONNX_NAMESPACE::AttributeProto_AttributeType_UNDEFINED:
      fail_convert("Unknown tensor data type");
      break;
  }
}

static void convertAttributes(const ONNX_NAMESPACE::NodeProto& np, Node* n, const int ir_version = IR_VERSION) {
  for (int i = 0; i < np.attribute_size(); i++) {
    convertAttribute(np.attribute(i), n, ir_version);
  }
}

static std::vector<Dimension> tensorShapeProtoToDimensions(const ONNX_NAMESPACE::TensorShapeProto& tsp) {
  std::vector<Dimension> dims;
  dims.reserve(tsp.dim_size());
  for (int i = 0; i < tsp.dim_size(); i++) {
    if (tsp.dim(i).has_dim_value()) {
      dims.emplace_back(tsp.dim(i).dim_value());
    } else if (tsp.dim(i).has_dim_param()) {
      dims.emplace_back(tsp.dim(i).dim_param());
    } else {
      // a dimension that has neither dim_value nor dim_param set
      // represents an unknown dimension unrelated to other unknown dimensions.
      dims.emplace_back();
    }
  }
  return dims;
}

static void createDummyValue(
    const std::unique_ptr<Graph>& g,
    const std::string& name,
    std::unordered_map<std::string, Value*>& value_by_name_of) {
  auto* undef = g->create(kCaptured, 1);
  g->appendNode(undef);
  undef->outputs()[0]->setUniqueName(name);
  value_by_name_of[name] = undef->outputs()[0];
}

std::unique_ptr<Graph> graphProtoToGraph(const ONNX_NAMESPACE::GraphProto& gp, bool nested, const int ir_version) {
  auto g = std::make_unique<Graph>();

  if (gp.has_name()) {
    g->setName(gp.name());
  }
  if (gp.has_doc_string()) {
    g->setDocString(gp.doc_string());
  }

  // Values are created (as in `new Value(..)`) by the Node that
  // outputs them. Therefore we initialize the Nodes and Values in
  // several stages.
  //
  // 1) add all input (to the graph) Values, owned by the sentinel Param node
  // 2) add all Nodes and their output Values, but don't initialize inputs
  // 3) initialize inputs of all Nodes
  // 4) initialize inputs of the Return sentinel node
  // 5) fill in type info for graph outputs, and register them as outputs
  // 6) fill in type info for Values from the value_info list in the graph

  // In ONNX proto land, Values are just strings. We are going to make
  // objects out of them, and equal strings must be mapped to the same
  // Value object.
  std::unordered_map<std::string, Value*> value_by_name_of;

  // We initialize Node inputs in a separate pass from the Nodes
  // themselves. To do so, we need to have access to the names of the
  // inputs.
  std::unordered_map<Node*, std::vector<std::string>> inputs_by_node;

  {
    // ONNX represents optional arguments in two ways
    //  - they are simply not provided
    //  - OR the empty string is passed as the input name
    // This is to handle that second case, which needs a dummy node to
    // be representable in the graph IR.
    auto* n = g->create(kUndefined, 1);
    g->appendNode(n);
    n->outputs()[0]->setUniqueName("");
    value_by_name_of[""] = n->outputs()[0];
  }

  for (int i = 0; i < gp.input_size(); i++) {
    const auto& vip = gp.input(i);
    auto v = g->addInput();
    const auto& tensor_type = vip.type().tensor_type();
    if (tensor_type.has_elem_type()) {
      v->setElemType(tensor_type.elem_type());
    }
    if (tensor_type.has_shape()) {
      v->setSizes(tensorShapeProtoToDimensions(tensor_type.shape()));
    }
    v->setUniqueName(vip.name());
    value_by_name_of[vip.name()] = v;
  }

  // initializers should be added before all nodes,
  // otherwise getNextUnique() may conflicts with an existing initializer name.
  for (int i = 0; i < gp.initializer_size(); ++i) {
    auto init = tensorProtoToTensor(gp.initializer(i));
    // If ir_version >= 4, initializer does not have to be included in input
    // Create a Value from initializer by addInitializerNode if name does not exist in input
    // and save it into value_by_name_of for later use (node input)
    if (ir_version >= 4 && value_by_name_of.count(init.name()) == 0) {
      value_by_name_of[init.name()] = g->addInitializerAndCreateValue(init);
    } else {
      // If ir_version < 4 or the initializer exists in input
      // Simply add initializer without creating new value
      // which means it will prioritize input value over initializer value if both exist
      g->addInitializer(init);
    }
  }

  for (int i = 0; i < gp.node_size(); i++) {
    const auto& np = gp.node(i);
    auto* n = g->create(Symbol(np.op_type()), /* num_outputs = */ np.output_size());
    g->appendNode(n);
    for (int j = 0; j < np.output_size(); j++) {
      auto out = n->outputs()[j];
      // we don't know the real type here, so that's done in a later pass
      out->setElemType(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
      out->setUniqueName(np.output(j));
      value_by_name_of[np.output(j)] = out;
    }
    convertAttributes(np, n, ir_version);
    std::vector<std::string> inputs;
    inputs.reserve(np.input_size());
    for (int j = 0; j < np.input_size(); j++) {
      inputs.push_back(np.input(j));
    }
    inputs_by_node[n] = inputs;
    if (np.has_doc_string()) {
      n->setDocString(np.doc_string());
    }
    if (np.has_name()) {
      n->setName(np.name());
    }
    if (np.has_domain()) {
      n->setDomain(np.domain());
    }
    if (np.has_overload()) {
      n->setOverload(np.overload());
    }
  }

  for (auto n : g->nodes()) {
    auto search = inputs_by_node.find(n);
    if (search == inputs_by_node.end()) {
      continue;
    }
    for (const auto& input : search->second) {
      if (!value_by_name_of.count(input) && nested) {
        // Undefined reference to an input in a nested block. This may be a
        // captured value. Create a dummy node that we ignore later.
        createDummyValue(g, input, value_by_name_of);
      }

      if (!value_by_name_of.count(input)) {
        std::ostringstream msg;
        msg << "Input " << input << " is undefined!";
        ONNX_THROW_EX(std::out_of_range(msg.str()));
      }
      n->addInput(value_by_name_of.at(input));
    }
  }

  for (int i = 0; i < gp.output_size(); i++) {
    if (!value_by_name_of.count(gp.output(i).name()) && nested) {
      // Same captured value logic as above. We can consider outputs of a
      // graph to be "inputs" of a dummy "output" node. The same lexical
      // scoping rules are valid here, thus we need to add a dummy node
      // in the case of an undefined reference
      createDummyValue(g, gp.output(i).name(), value_by_name_of);
    }
    const auto& output_tensor_type = gp.output(i).type().tensor_type();
    if (output_tensor_type.has_elem_type()) {
      value_by_name_of[gp.output(i).name()]->setElemType(output_tensor_type.elem_type());
    }
    if (output_tensor_type.has_shape()) {
      value_by_name_of[gp.output(i).name()]->setSizes(tensorShapeProtoToDimensions(output_tensor_type.shape()));
    }
    g->registerOutput(value_by_name_of[gp.output(i).name()]);
  }

  for (int i = 0; i < gp.value_info_size(); i++) {
    const auto& tensor_type = gp.value_info(i).type().tensor_type();
    if (!value_by_name_of.count(gp.value_info(i).name())) {
      // Ideally the model should not have a value_info whose name does not exist in the graph (unused); simply skip it
      continue;
    }
    if (tensor_type.has_elem_type()) {
      value_by_name_of[gp.value_info(i).name()]->setElemType(tensor_type.elem_type());
    }
    if (tensor_type.has_shape()) {
      value_by_name_of[gp.value_info(i).name()]->setSizes(tensorShapeProtoToDimensions(tensor_type.shape()));
    }
  }

  return g;
}

std::unique_ptr<Graph> ImportModelProto(const ModelProto& mp) {
  if (!mp.has_ir_version()) {
    return nullptr;
  }
  if (mp.ir_version() <= 1) {
    // ir_version=1 is not supported and ir_version=0 is illegal
    return nullptr;
  }

  std::unique_ptr<Graph> g(graphProtoToGraph(mp.graph(), false, mp.ir_version()));
  for (int i = 0; i < mp.opset_import_size(); i++) {
    OpSetID new_opset_version(mp.opset_import(i).domain(), mp.opset_import(i).version());
    g->forSelfAndEachSubGraph(
        [&new_opset_version](Graph* graph) { graph->opset_versions_mutable().emplace_back(new_opset_version); });
  }
  return g;
}

// Part 2: convert IR to ONNX Protobuf
static std::string value_name(const Value* n) {
  return n->uniqueName();
}

static void encodeGraph(GraphProto* p_g, const std::shared_ptr<Graph>& g);

static void encodeTensor(ONNX_NAMESPACE::TensorProto* p, const Tensor& tensor) {
  if (tensor.hasName()) {
    p->set_name(tensor.name());
  }
  if (tensor.is_segment()) {
    ONNX_NAMESPACE::TensorProto_Segment segment;
    segment.set_begin(tensor.segment_begin());
    segment.set_end(tensor.segment_end());
    p->mutable_segment()->CopyFrom(segment);
  }
  for (auto d : tensor.sizes()) {
    p->add_dims(d);
  }
  p->set_data_type(tensor.elem_type());
  switch (tensor.elem_type()) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
    case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64: {
      for (float x : tensor.floats()) {
        p->add_float_data(x);
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1:
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT2:
    case ONNX_NAMESPACE::TensorProto_DataType_INT2:
    case ONNX_NAMESPACE::TensorProto_DataType_INT4:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
      for (int32_t x : tensor.int32s()) {
        p->add_int32_data(x);
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
      for (int64_t x : tensor.int64s()) {
        p->add_int64_data(x);
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {
      for (uint64_t x : tensor.uint64s()) {
        p->add_uint64_data(x);
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
    case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128: {
      for (double x : tensor.doubles()) {
        p->add_double_data(x);
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_STRING: {
      for (const std::string& x : tensor.strings()) {
        p->add_string_data(x);
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
      [[fallthrough]];
    default:
      fail_convert("Unknown tensor data type");
  }
  if (tensor.is_raw_data()) {
    p->set_raw_data(tensor.raw());
  }
  if (!tensor.external_data().empty()) {
    for (const auto& entry : tensor.external_data()) {
      auto* external_data = p->add_external_data();
      external_data->set_key(entry.first);
      external_data->set_value(entry.second);
    }
  }
  if (tensor.has_data_location()) {
    p->set_data_location(tensor.data_location());
  }
}

static void addAttribute(ONNX_NAMESPACE::NodeProto* n_p, Node* n, Symbol name) {
  auto attr = n_p->add_attribute();
  attr->set_name(name.toString());
  switch (n->kindOf(name)) {
    case AttributeKind::f: {
      attr->set_f(static_cast<float>(n->f(name)));
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT);
    } break;
    case AttributeKind::fs: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS);
      for (auto& v : n->fs(name))
        attr->add_floats(static_cast<float>(v));
    } break;
    case AttributeKind::i: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT);
      attr->set_i(n->i(name));
    } break;
    case AttributeKind::is: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INTS);
      for (auto& v : n->is(name))
        attr->add_ints(v);
    } break;
    case AttributeKind::s: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING);
      attr->set_s(n->s(name));
    } break;
    case AttributeKind::ss: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS);
      for (auto& v : n->ss(name))
        attr->add_strings(v);
    } break;
    case AttributeKind::t: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR);
      auto t = attr->mutable_t();
      encodeTensor(t, n->t(name));
    } break;
    case AttributeKind::ts: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS);
      for (auto& v : n->ts(name)) {
        auto t = attr->add_tensors();
        encodeTensor(t, v);
      }
    } break;
    case AttributeKind::g: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH);
      auto g = attr->mutable_g();
      encodeGraph(g, n->g(name));
    } break;
    case AttributeKind::gs: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS);
      for (auto& v : n->gs(name)) {
        auto g = attr->add_graphs();
        encodeGraph(g, v);
      }
    } break;
    case AttributeKind::tp: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTO);
      auto tp = attr->mutable_tp();
      tp->CopyFrom(n->tp(name));
    } break;
    case AttributeKind::tps: {
      attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTOS);
      for (auto& v : n->tps(name)) {
        auto tp = attr->add_type_protos();
        tp->CopyFrom(v);
      }
    } break;
  }
}

static void encodeTypeProtoTensorType(ONNX_NAMESPACE::TypeProto_Tensor* tensor_type, const Value* n) {
  if (n->elemType() != 0) {
    tensor_type->set_elem_type(n->elemType());
  }
  if (n->has_sizes()) {
    ONNX_NAMESPACE::TensorShapeProto* shape = tensor_type->mutable_shape();
    for (const Dimension& d : n->sizes()) {
      auto dim = shape->add_dim();
      if (!d.is_unknown) {
        if (d.is_int) {
          dim->set_dim_value(d.dim);
        } else {
          dim->set_dim_param(d.param);
        }
      }
    }
  }
}

static void encodeValueInfo(ONNX_NAMESPACE::ValueInfoProto* v, Value* n) {
  v->set_name(value_name(n));
  if (n->elemType() != 0 || n->has_sizes()) {
    ONNX_NAMESPACE::TypeProto* t = v->mutable_type();
    ONNX_NAMESPACE::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
    encodeTypeProtoTensorType(tensor_type, n);
  }
}

void encodeGraph(GraphProto* p_g, const std::shared_ptr<Graph>& g) {
  ONNX_ASSERT(p_g != nullptr)

  if (g->has_name()) {
    p_g->set_name(g->name());
  }

  if (g->has_doc_string()) {
    p_g->set_doc_string(g->docString());
  }

  for (auto input : g->inputs()) {
    ONNX_NAMESPACE::ValueInfoProto* v = p_g->add_input();
    encodeValueInfo(v, input);
  }
  for (auto output : g->outputs()) {
    ONNX_NAMESPACE::ValueInfoProto* v = p_g->add_output();
    encodeValueInfo(v, output);
  }

  std::unordered_set<Value*> graph_outputs(g->outputs().begin(), g->outputs().end());

  for (auto node : g->nodes()) {
    if (node->kind() == kUndefined || node->kind() == kCaptured) {
      // Undefined nodes are used to represent optional inputs that are not
      // provided.
      continue;
    }
    auto p_n = p_g->add_node();
    for (auto input : node->inputs()) {
      if (input->node()->kind() == kUndefined) {
        p_n->add_input("");
      } else {
        p_n->add_input(value_name(input));
      }
    }
    for (auto output : node->outputs()) {
      p_n->add_output(value_name(output));
      // only save it if
      //  - it has actual information worth saving
      //  - it's not already saved in the graph outputs value info
      if (graph_outputs.find(output) != graph_outputs.end()) {
        continue;
      }
      if (output->elemType() == TensorProto_DataType_UNDEFINED && output->sizes().empty()) {
        continue;
      }
      ValueInfoProto* v = p_g->add_value_info();
      encodeValueInfo(v, output);
    }
    p_n->set_op_type(node->kind().toString());
    for (auto attr_name : node->attributeNames()) {
      addAttribute(p_n, node, attr_name);
    }
    if (node->has_doc_string()) {
      p_n->set_doc_string(node->docString());
    }
    if (node->has_name()) {
      p_n->set_name(node->name());
    }
    if (node->has_domain()) {
      p_n->set_domain(node->domain());
    }
    if (node->has_overload()) {
      p_n->set_overload(node->overload());
    }
  }

  auto num_initializers = g->initializers().size();
  for (unsigned int i = 0; i < num_initializers; i++) {
    auto p = p_g->add_initializer();
    p->set_name(g->initializer_names()[i]);
    encodeTensor(p, g->initializers()[i]);
  }
}

void ExportModelProto(ModelProto* p_m, const std::shared_ptr<Graph>& g) {
  GraphProto* p_g = p_m->mutable_graph();
  encodeGraph(p_g, g);
  // Add new opset_versions
  p_m->clear_opset_import();
  for (const OpSetID& opset : g->opset_versions_mutable()) {
    OperatorSetIdProto* opset_version_output = p_m->add_opset_import();
    opset_version_output->set_domain(opset.domain());
    opset_version_output->set_version(opset.version());
  }
}

ModelProto PrepareOutput(const ModelProto& mp_in) {
  ModelProto mp_out{};

  if (mp_in.has_ir_version()) {
    mp_out.set_ir_version(mp_in.ir_version());
  }
  if (mp_in.has_producer_name()) {
    mp_out.set_producer_name(mp_in.producer_name());
  }
  if (mp_in.has_producer_version()) {
    mp_out.set_producer_version(mp_in.producer_version());
  }
  if (mp_in.has_domain()) {
    mp_out.set_domain(mp_in.domain());
  }
  if (mp_in.has_model_version()) {
    mp_out.set_model_version(mp_in.model_version());
  }
  if (mp_in.has_doc_string()) {
    mp_out.set_doc_string(mp_in.doc_string());
  }
  for (int i = 0; i < mp_in.opset_import_size(); i++) {
    auto& oi_in = mp_in.opset_import(i);
    auto* oi_out = mp_out.add_opset_import();
    if (oi_in.has_domain()) {
      oi_out->set_domain(oi_in.domain());
    }
    if (oi_in.has_version()) {
      oi_out->set_version(oi_in.version());
    }
  }
  for (int i = 0; i < mp_in.metadata_props_size(); i++) {
    auto& pp_in = mp_in.metadata_props(i);
    auto* pp_out = mp_out.add_metadata_props();
    if (pp_in.has_key()) {
      pp_out->set_key(pp_in.key());
    }
    if (pp_in.has_value()) {
      pp_out->set_value(pp_in.value());
    }
  }

  return mp_out;
}

void assertNonNull(const std::shared_ptr<Graph>& g) {
  ONNX_ASSERTM(
      g.get() != nullptr,
      "Warning: onnx version converter is unable to parse input model. "
      "(The IR version of the ONNX model may be too old.)")
}

} // namespace ONNX_NAMESPACE
