/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "onnx/defs/printer.h"

#include <iomanip>

#include "onnx/defs/tensor_proto_util.h"

namespace ONNX_NAMESPACE {

// Repeated key/value string entries; printed below for a model's metadata_props and a tensor's external_data.
using StringStringEntryProtos = google::protobuf::RepeatedPtrField<StringStringEntryProto>;

// Returns true if str is a non-empty identifier: a letter or '_' followed by
// letters, digits, or '_' (as classified by isalpha/isalnum).
// Any other byte, including an embedded NUL, makes the string invalid.
static bool IsValidIdentifier(const std::string& str) {
  if (str.empty())
    return false; // empty string is not a valid identifier
  // Bytes are converted to unsigned char before classification: passing a negative
  // char value (e.g. a UTF-8 byte when char is signed) to isalpha/isalnum is undefined behavior.
  const unsigned char first = static_cast<unsigned char>(str[0]);
  if (!isalpha(first) && (first != '_'))
    return false; // first character must be a letter or '_'
  for (std::string::size_type i = 1; i < str.size(); ++i) {
    const unsigned char ch = static_cast<unsigned char>(str[i]);
    if (!isalnum(ch) && (ch != '_'))
      return false;
  }
  return true;
}

// Writes ONNX protos to an output stream in a human-readable text form.
//
// Every print overload writes directly to the wrapped stream. Nested subgraphs are
// indented through indent()/outdent(), which move indent_level in steps of 3.
class ProtoPrinter {
 public:
  explicit ProtoPrinter(std::ostream& os) : output_(os) {}

  void print(const TensorShapeProto_Dimension& dim);

  void print(const TensorShapeProto& shape);

  void print(const TypeProto_Tensor& tensortype);

  void print(const TypeProto& type);

  void print(const TypeProto_Sequence& seqType);

  void print(const TypeProto_Map& mapType);

  void print(const TypeProto_Optional& optType);

  void print(const TypeProto_SparseTensor& sparseType);

  // When is_initializer is true, " = " is written between the tensor header and its data.
  void print(const TensorProto& tensor, bool is_initializer = false);

  void print(const ValueInfoProto& value_info);

  void print(const ValueInfoList& vilist);

  void print(const AttributeProto& attr);

  void print(const AttrList& attrlist);

  void print(const NodeProto& node);

  void print(const NodeList& nodelist);

  void print(const GraphProto& graph);

  void print(const FunctionProto& fn);

  void print(const ModelProto& model);

  void print(const OperatorSetIdProto& opset);

  void print(const OpsetIdList& opsets);

  // Prints the entries as a bracketed, comma-separated list of "key": "value" pairs.
  void print(const StringStringEntryProtos& stringStringProtos) {
    printSet("[", ", ", "]", stringStringProtos);
  }

  // Prints one entry as "key": "value", with both sides quoted and escaped.
  void print(const StringStringEntryProto& metadata) {
    printQuoted(metadata.key());
    output_ << ": ";
    printQuoted(metadata.value());
  }

 private:
  // Prints str bare when it is a valid identifier, otherwise as a quoted, escaped string.
  void printId(const std::string& str) {
    if (IsValidIdentifier(str))
      output_ << str;
    else
      printQuoted(str);
  }

  // Fallback for values with no dedicated overload (e.g. integers, floats, plain strings):
  // streams them unchanged.
  template <typename T>
  void print(const T& prim) {
    output_ << prim;
  }

  // Writes str surrounded by double quotes, escaping '\\' and '"' with a backslash.
  // Iterates over the whole std::string (not c_str()) so that content after an
  // embedded NUL byte is not silently dropped.
  void printQuoted(const std::string& str) {
    output_ << "\"";
    for (const char ch : str) {
      if ((ch == '\\') || (ch == '"'))
        output_ << '\\';
      output_ << ch;
    }
    output_ << "\"";
  }

  // Writes ",\n" (unless addsep is false), the current indentation, "key: ", then the value.
  template <typename T>
  void printKeyValuePair(KeyWordMap::KeyWord key, const T& val, bool addsep = true) {
    if (addsep)
      output_ << "," << '\n';
    output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
    print(val);
  }

  // String-valued variant: always writes the ",\n" separator and quotes/escapes the value.
  void printKeyValuePair(KeyWordMap::KeyWord key, const std::string& val) {
    output_ << "," << '\n';
    output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
    printQuoted(val);
  }

  // Writes open, each element of coll via print() joined by separator, then close.
  template <typename Collection>
  void printSet(const char* open, const char* separator, const char* close, const Collection& coll) {
    const char* sep = "";
    output_ << open;
    for (auto& elt : coll) {
      output_ << sep;
      print(elt);
      sep = separator;
    }
    output_ << close;
  }

  // Same as printSet, but each element is written with printId().
  template <typename Collection>
  void printIdSet(const char* open, const char* separator, const char* close, const Collection& coll) {
    const char* sep = "";
    output_ << open;
    for (auto& elt : coll) {
      output_ << sep;
      printId(elt);
      sep = separator;
    }
    output_ << close;
  }

  std::ostream& output_;
  // Field width passed to std::setw when indenting lines; starts at 3.
  int indent_level = 3;

  // Increase the indentation for nested content (e.g. graph-valued attributes).
  void indent() {
    indent_level += 3;
  }

  // Undo one indent().
  void outdent() {
    indent_level -= 3;
  }
};

// Prints a single dimension: its numeric value, else its symbolic name, else "?".
void ProtoPrinter::print(const TensorShapeProto_Dimension& dim) {
  if (dim.has_dim_value()) {
    output_ << dim.dim_value();
    return;
  }
  if (dim.has_dim_param()) {
    output_ << dim.dim_param();
    return;
  }
  output_ << "?";
}

// Prints the dimensions as a bracketed list separated by "," with no spaces.
void ProtoPrinter::print(const TensorShapeProto& shape) {
  const auto& dims = shape.dim();
  printSet("[", ",", "]", dims);
}

// Prints a tensor type as its element type name followed by the shape.
// A missing shape is written as "[]"; a shape with zero dimensions adds nothing.
void ProtoPrinter::print(const TypeProto_Tensor& tensortype) {
  output_ << PrimitiveTypeNameMap::ToString(tensortype.elem_type());
  if (!tensortype.has_shape()) {
    output_ << "[]";
  } else if (tensortype.shape().dim_size() > 0) {
    print(tensortype.shape());
  }
}

void ProtoPrinter::print(const TypeProto_Sequence& seqType) {
  output_ << "seq(";
  print(seqType.elem_type());
  output_ << ")";
}

// Prints a map type as map(<key type name>, <value type>).
void ProtoPrinter::print(const TypeProto_Map& mapType) {
  const auto& key_name = PrimitiveTypeNameMap::ToString(mapType.key_type());
  output_ << "map(" << key_name << ", ";
  print(mapType.value_type());
  output_ << ")";
}

void ProtoPrinter::print(const TypeProto_Optional& optType) {
  output_ << "optional(";
  print(optType.elem_type());
  output_ << ")";
}

// Prints a sparse tensor type as sparse_tensor(<element type><shape>), using the same
// shape convention as dense tensors: missing shape -> "[]", zero dimensions -> nothing.
void ProtoPrinter::print(const TypeProto_SparseTensor& sparseType) {
  output_ << "sparse_tensor(" << PrimitiveTypeNameMap::ToString(sparseType.elem_type());
  if (!sparseType.has_shape()) {
    output_ << "[]";
  } else if (sparseType.shape().dim_size() > 0) {
    print(sparseType.shape());
  }
  output_ << ")";
}

// Dispatches to the printer for whichever type variant is set.
// A TypeProto with none of these variants set prints nothing.
void ProtoPrinter::print(const TypeProto& type) {
  if (type.has_tensor_type()) {
    print(type.tensor_type());
    return;
  }
  if (type.has_sequence_type()) {
    print(type.sequence_type());
    return;
  }
  if (type.has_map_type()) {
    print(type.map_type());
    return;
  }
  if (type.has_optional_type()) {
    print(type.optional_type());
    return;
  }
  if (type.has_sparse_tensor_type()) {
    print(type.sparse_tensor_type());
  }
}

// Prints a tensor as "<elem type>[dims] <name>" followed by its data.
// If is_initializer is true, " = " is written before the data.
// External data prints its location entries; raw data is decoded for a few numeric
// types ("..." otherwise); typed-field data is printed for the element types listed below.
void ProtoPrinter::print(const TensorProto& tensor, bool is_initializer) {
  output_ << PrimitiveTypeNameMap::ToString(tensor.data_type());
  if (tensor.dims_size() > 0)
    printSet("[", ",", "]", tensor.dims());

  if (!tensor.name().empty()) {
    output_ << " ";
    printId(tensor.name());
  }
  if (is_initializer) {
    output_ << " = ";
  }
  // TODO: does not yet handle all types
  if (tensor.has_data_location() && tensor.data_location() == TensorProto_DataLocation_EXTERNAL) {
    print(tensor.external_data());
  } else if (tensor.has_raw_data()) {
    switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
      case TensorProto::DataType::TensorProto_DataType_INT32:
        printSet(" {", ",", "}", ParseData<int32_t>(&tensor));
        break;
      case TensorProto::DataType::TensorProto_DataType_INT64:
        printSet(" {", ",", "}", ParseData<int64_t>(&tensor));
        break;
      case TensorProto::DataType::TensorProto_DataType_FLOAT:
        printSet(" {", ",", "}", ParseData<float>(&tensor));
        break;
      case TensorProto::DataType::TensorProto_DataType_DOUBLE:
        printSet(" {", ",", "}", ParseData<double>(&tensor));
        break;
      default:
        output_ << "..."; // ParseData not instantiated for other types.
        break;
    }
  } else {
    switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
      // These element types are all read from int32_data.
      case TensorProto::DataType::TensorProto_DataType_INT8:
      case TensorProto::DataType::TensorProto_DataType_INT16:
      case TensorProto::DataType::TensorProto_DataType_INT32:
      case TensorProto::DataType::TensorProto_DataType_UINT8:
      case TensorProto::DataType::TensorProto_DataType_UINT16:
      case TensorProto::DataType::TensorProto_DataType_BOOL:
        printSet(" {", ",", "}", tensor.int32_data());
        break;
      case TensorProto::DataType::TensorProto_DataType_INT64:
        printSet(" {", ",", "}", tensor.int64_data());
        break;
      // Both uint32 and uint64 values are read from uint64_data.
      case TensorProto::DataType::TensorProto_DataType_UINT32:
      case TensorProto::DataType::TensorProto_DataType_UINT64:
        printSet(" {", ",", "}", tensor.uint64_data());
        break;
      case TensorProto::DataType::TensorProto_DataType_FLOAT:
        printSet(" {", ",", "}", tensor.float_data());
        break;
      case TensorProto::DataType::TensorProto_DataType_DOUBLE:
        printSet(" {", ",", "}", tensor.double_data());
        break;
      case TensorProto::DataType::TensorProto_DataType_STRING: {
        // The opening brace is written unconditionally so an empty string tensor
        // prints " {}" instead of a lone, unbalanced "}". The leading space matches
        // the other element types above.
        output_ << " {";
        const char* sep = "";
        for (const auto& elt : tensor.string_data()) {
          output_ << sep;
          printQuoted(elt);
          sep = ", ";
        }
        output_ << "}";
        break;
      }
      default:
        // Other element types are not printed yet.
        break;
    }
  }
}

void ProtoPrinter::print(const ValueInfoProto& value_info) {
  print(value_info.type());
  output_ << " ";
  printId(value_info.name());
}

// Prints the value infos as a parenthesized list separated by ", ".
void ProtoPrinter::print(const ValueInfoList& vilist) {
  const char* const open = "(";
  const char* const close = ")";
  printSet(open, ", ", close, vilist);
}

void ProtoPrinter::print(const AttributeProto& attr) {
  // Special case of attr-ref:
  if (attr.has_ref_attr_name()) {
    output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = @" << attr.ref_attr_name();
    return;
  }
  // General case:
  output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = ";
  switch (attr.type()) {
    case AttributeProto_AttributeType_INT:
      output_ << attr.i();
      break;
    case AttributeProto_AttributeType_INTS:
      printSet("[", ", ", "]", attr.ints());
      break;
    case AttributeProto_AttributeType_FLOAT:
      output_ << attr.f();
      break;
    case AttributeProto_AttributeType_FLOATS:
      printSet("[", ", ", "]", attr.floats());
      break;
    case AttributeProto_AttributeType_STRING:
      output_ << "\"" << attr.s() << "\"";
      break;
    case AttributeProto_AttributeType_STRINGS: {
      const char* sep = "[";
      for (auto& elt : attr.strings()) {
        output_ << sep << "\"" << elt << "\"";
        sep = ", ";
      }
      output_ << "]";
      break;
    }
    case AttributeProto_AttributeType_GRAPH:
      indent();
      print(attr.g());
      outdent();
      break;
    case AttributeProto_AttributeType_GRAPHS:
      indent();
      printSet("[", ", ", "]", attr.graphs());
      outdent();
      break;
    case AttributeProto_AttributeType_TENSOR:
      print(attr.t());
      break;
    case AttributeProto_AttributeType_TENSORS:
      printSet("[", ", ", "]", attr.tensors());
      break;
    case AttributeProto_AttributeType_TYPE_PROTO:
      print(attr.tp());
      break;
    case AttributeProto_AttributeType_TYPE_PROTOS:
      printSet("[", ", ", "]", attr.type_protos());
      break;
    default:
      break;
  }
}

// Prints the attributes between " <" and ">", separated by ", ".
void ProtoPrinter::print(const AttrList& attrlist) {
  const char* const separator = ", ";
  printSet(" <", separator, ">", attrlist);
}

// Prints one node on its own indented line:
//   [name] out1, out2 = domain.op_type:overload <attrs> (in1, in2)
// If any attribute holds a graph (or graphs), the attribute list is moved after the
// inputs instead of before them.
void ProtoPrinter::print(const NodeProto& node) {
  output_ << std::setw(indent_level) << ' ';
  if (node.has_name()) {
    output_ << "[";
    printId(node.name());
    output_ << "] ";
  }
  printIdSet("", ", ", "", node.output());
  output_ << " = ";

  const std::string& domain = node.domain();
  if (!domain.empty())
    output_ << domain << ".";
  output_ << node.op_type();
  const std::string& overload = node.overload();
  if (!overload.empty())
    output_ << ":" << overload;

  bool attrs_after_inputs = false;
  for (const auto& a : node.attribute()) {
    if (a.has_g() || (a.graphs_size() > 0)) {
      attrs_after_inputs = true;
      break;
    }
  }
  const bool has_attrs = node.attribute_size() > 0;

  if (has_attrs && !attrs_after_inputs)
    print(node.attribute());
  printIdSet(" (", ", ", ")", node.input());
  if (has_attrs && attrs_after_inputs)
    print(node.attribute());
  output_ << "\n";
}

// Prints the nodes inside a brace-delimited block, one node per line.
// When indent_level is above its base value of 3, the closing brace is padded
// using setw(indent_level - 3).
void ProtoPrinter::print(const NodeList& nodelist) {
  output_ << "{\n";
  for (const auto& n : nodelist)
    print(n);
  const int outer = indent_level - 3;
  if (outer > 0)
    output_ << std::setw(outer) << "   ";
  output_ << "}";
}

// Prints a graph as "<name> (<inputs>) => (<outputs>)", an optional line listing the
// initializers and value_info entries inside <...>, and then the node block.
void ProtoPrinter::print(const GraphProto& graph) {
  printId(graph.name());
  output_ << " " << graph.input() << " => " << graph.output() << " ";

  const bool has_decls = (graph.initializer_size() > 0) || (graph.value_info_size() > 0);
  if (has_decls) {
    output_ << '\n' << std::setw(indent_level) << ' ' << '<';
    // Initializers and value_info entries share one ", "-separated list.
    bool first = true;
    for (const auto& init : graph.initializer()) {
      if (!first)
        output_ << ", ";
      first = false;
      print(init, true);
    }
    for (const auto& vi : graph.value_info()) {
      if (!first)
        output_ << ", ";
      first = false;
      print(vi);
    }
    output_ << ">\n";
  }
  print(graph.node());
}

void ProtoPrinter::print(const ModelProto& model) {
  output_ << "<\n";
  printKeyValuePair(KeyWordMap::KeyWord::IR_VERSION, model.ir_version(), false);
  printKeyValuePair(KeyWordMap::KeyWord::OPSET_IMPORT, model.opset_import());
  if (model.has_producer_name())
    printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_NAME, model.producer_name());
  if (model.has_producer_version())
    printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_VERSION, model.producer_version());
  if (model.has_domain())
    printKeyValuePair(KeyWordMap::KeyWord::DOMAIN_KW, model.domain());
  if (model.has_model_version())
    printKeyValuePair(KeyWordMap::KeyWord::MODEL_VERSION, model.model_version());
  if (model.has_doc_string())
    printKeyValuePair(KeyWordMap::KeyWord::DOC_STRING, model.doc_string());
  if (model.metadata_props_size() > 0)
    printKeyValuePair(KeyWordMap::KeyWord::METADATA_PROPS, model.metadata_props());
  output_ << '\n' << ">" << '\n';

  print(model.graph());
  for (const auto& fn : model.functions()) {
    output_ << '\n';
    print(fn);
  }
}

void ProtoPrinter::print(const OperatorSetIdProto& opset) {
  output_ << "\"" << opset.domain() << "\" : " << opset.version();
}

// Prints the opset imports as a bracketed list separated by ", ".
void ProtoPrinter::print(const OpsetIdList& opsets) {
  const char* const separator = ", ";
  printSet("[", separator, "]", opsets);
}

void ProtoPrinter::print(const FunctionProto& fn) {
  output_ << "<\n";
  output_ << "  "
          << "domain: \"" << fn.domain() << "\",\n";
  if (!fn.overload().empty())
    output_ << "  "
            << "overload: \"" << fn.overload() << "\",\n";

  output_ << "  "
          << "opset_import: ";
  printSet("[", ",", "]", fn.opset_import());
  output_ << "\n>\n";
  printId(fn.name());
  output_ << " ";
  if (fn.attribute_size() > 0)
    printSet("<", ",", ">", fn.attribute());
  printIdSet("(", ", ", ")", fn.input());
  output_ << " => ";
  printIdSet("(", ", ", ")", fn.output());
  output_ << "\n";
  print(fn.node());
}

// Defines `std::ostream& operator<<(std::ostream&, const T&)` for proto type T by
// delegating to a fresh ProtoPrinter on that stream. No trailing ';' follows the
// function body, so expansions do not leave empty declarations at namespace scope
// (-Wextra-semi / -pedantic).
#define DEF_OP(T)                                              \
  std::ostream& operator<<(std::ostream& os, const T& proto) { \
    ProtoPrinter printer(os);                                  \
    printer.print(proto);                                      \
    return os;                                                 \
  }

DEF_OP(TensorShapeProto_Dimension)

DEF_OP(TensorShapeProto)

DEF_OP(TypeProto_Tensor)

DEF_OP(TypeProto)

DEF_OP(TensorProto)

DEF_OP(ValueInfoProto)

DEF_OP(ValueInfoList)

DEF_OP(AttributeProto)

DEF_OP(AttrList)

DEF_OP(NodeProto)

DEF_OP(NodeList)

DEF_OP(GraphProto)

DEF_OP(FunctionProto)

DEF_OP(ModelProto)

// The helper macro is local to this file; do not leak it to later code.
#undef DEF_OP

} // namespace ONNX_NAMESPACE
