/*
 * SPDX-License-Identifier: Apache-2.0
 */

// Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
// by this parser is preliminary and may change.

#include "onnx/defs/parser.h"

#include <algorithm>
#include <cctype>
#include <limits>
#include <string>

#include "onnx/common/common.h"

// Shorthand wrappers that pass the result of a parser call to CHECK_PARSER_STATUS
// (defined elsewhere; presumably it returns early from the enclosing function on a non-OK status).
//   PARSE_TOKEN(x): calls the ParserBase::Parse overload for x.
//   PARSE(...):     calls the Parse overload selected by the argument list.
//   MATCH(...):     calls Match with the given arguments.
#define PARSE_TOKEN(x) CHECK_PARSER_STATUS(ParserBase::Parse(x))
#define PARSE(...) CHECK_PARSER_STATUS(Parse(__VA_ARGS__))
#define MATCH(...) CHECK_PARSER_STATUS(Match(__VA_ARGS__))

namespace ONNX_NAMESPACE {

// Parses the next literal token into `result`.
//
// Recognized forms:
//   - string literal: "..." where a backslash escapes the character that follows it;
//   - special float literal: an optional '-' followed by a spelling accepted by NextIsValidFloatString();
//   - numeric literal: an optional '-' followed by digits with at most one decimal point and an
//     optional exponent. It is a FLOAT_LITERAL when it has a decimal point or an exponent, and an
//     INT_LITERAL otherwise.
//
// Returns a parse error for an unterminated string or for an alphabetic token that is not a valid
// special float. If the input starts with none of the forms above, the function returns OK and
// leaves `result` untouched (a leading '-' has still been consumed in that case).
Status ParserBase::Parse(Literal& result) {
  bool decimal_point = false;
  auto nextch = NextChar();
  auto from = next_;
  if (nextch == '"') {
    ++next_;
    bool has_escape = false;
    while ((next_ < end_) && (*next_ != '"')) {
      if (*next_ == '\\') {
        has_escape = true;
        ++next_;
        if (next_ >= end_)
          return ParseError("Incomplete string literal.");
      }
      ++next_;
    }
    if (next_ >= end_)
      return ParseError("Incomplete string literal.");
    ++next_;
    result.type = LiteralType::STRING_LITERAL;
    if (has_escape) {
      std::string& target = result.value;
      target.clear();
      target.reserve(next_ - from - 2); // upper bound
      // *from is the starting quote. *(next_-1) is the ending quote.
      // Copy what is in-between, except for the escape character
      while (++from < next_ - 1) {
        // Copy current char, if not escape, or next char otherwise.
        target.push_back(*from != '\\' ? (*from) : *(++from));
      }
    } else
      result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes
    return Status::OK();
  }

  // Simplify the next ifs by consuming a possible negative sign.
  if (nextch == '-') {
    ++next_;
    nextch = NextChar();
  }

  // The <cctype> classifiers have undefined behavior for negative char values, so every byte is
  // converted to unsigned char before it is classified.

  // Check for float literals that start with alphabet characters.
  if (std::isalpha(static_cast<unsigned char>(nextch))) {
    // Has to be a special float literal now: (-)*(nan|inf|infinity).
    if (!NextIsValidFloatString())
      return ParseError("Encountered invalid float literal!");

    while (next_ < end_ && std::isalpha(static_cast<unsigned char>(*next_))) {
      ++next_;
    }

    // A `return` inside the handler lambda would only leave the lambda, so the failure is
    // recorded in a flag and reported after the handler has run.
    bool conversion_ok = true;
    ONNX_TRY {
      static_cast<void>(std::stof(std::string(from, next_ - from)));
      result.type = LiteralType::FLOAT_LITERAL;
      result.value = std::string(from, next_ - from);
    }
    ONNX_CATCH(...) {
      ONNX_HANDLE_EXCEPTION([&]() { conversion_ok = false; });
    }
    if (!conversion_ok)
      return ParseError("Encountered invalid float literal!");
    return Status::OK();
  }

  // Checking for numeric ints or float literal.
  if (std::isdigit(static_cast<unsigned char>(nextch))) {
    ++next_;

    while ((next_ < end_) && (std::isdigit(static_cast<unsigned char>(*next_)) || (*next_ == '.'))) {
      if (*next_ == '.') {
        if (decimal_point)
          break; // Only one decimal point allowed in numeric literal
        decimal_point = true;
      }
      ++next_;
    }

    // Optional exponent syntax: (e|E)(+|-)?[0-9]+
    // (The digits after the exponent marker are not enforced here.)
    if ((next_ < end_) && ((*next_ == 'e') || (*next_ == 'E'))) {
      decimal_point = true; // treat as float-literal
      ++next_;
      if ((next_ < end_) && ((*next_ == '+') || (*next_ == '-')))
        ++next_;
      while ((next_ < end_) && std::isdigit(static_cast<unsigned char>(*next_)))
        ++next_;
    }

    result.value = std::string(from, next_ - from);
    result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL;
  }
  return Status::OK();
}

// Lookahead: returns true if the upcoming run of letters spells "inf", "infinity" or "nan"
// (case-insensitive) and is not immediately followed by a digit.
// The position `next_` is restored to where it was after the initial NextChar() call before returning.
bool ParserBase::NextIsValidFloatString() {
  auto nextch = NextChar();
  auto from = next_;
  constexpr int INFINITY_LENGTH = 8;

  // <cctype> classifiers are undefined for negative char values; classify bytes as unsigned char.
  if (!std::isalpha(static_cast<unsigned char>(nextch)))
    return false;

  // Scan at most INFINITY_LENGTH + 1 letters. Reading one letter more than the longest valid
  // spelling ensures that a longer word (e.g. "infinityx") cannot compare equal to a valid one.
  while (next_ < end_ && std::isalpha(static_cast<unsigned char>(*next_)) && (next_ - from) <= INFINITY_LENGTH) {
    ++next_;
  }

  // No trailing digits. The scan may have stopped at the end of the input, so check the bound
  // before dereferencing.
  const bool trailing_digit = (next_ < end_) && (std::isdigit(static_cast<unsigned char>(*next_)) != 0);

  std::string candidate = std::string(from, next_ - from);

  // Reset parser location before continuing.
  next_ = from;

  if (trailing_digit)
    return false;

  std::transform(
      candidate.begin(), candidate.end(), candidate.begin(), [](unsigned char c) { return std::tolower(c); });
  return candidate == "inf" || candidate == "infinity" || candidate == "nan";
}

// Parsing an IdList (list of identifiers separated by commas, where identifiers are allowed to be empty).
// Used to represent the list of inputs or outputs of a node.
// An empty identifier may be represented by an empty string "" or by nothing followed by a single comma.
// "Op()" has no operands
// "Op(,x)" has two operands, the first being empty.
// 'Op("")' has one operand, which is an empty string.
// 'Op(,)' has one operand, which is an empty string.
// Thus, this will also allow a trailing comma after a non-empty identifier with no effect.
// 'Op(x,)' has one operand, which is 'x'.
//
// This is mostly for some backward compatibility. "" is a simpler way to represent an
// empty identifier that is less confusing and is recommended.

// Reads a comma-separated list of optional, quotable identifiers into `idlist` (cleared first).
// Reading stops at the first position where no identifier is found or no comma follows.
Status OnnxParser::Parse(IdList& idlist) {
  idlist.Clear();
  std::string ident;
  bool present = false;
  CHECK_PARSER_STATUS(ParseOptionalQuotableIdentifier(ident, present));
  while (present) {
    *idlist.Add() = ident;
    if (!Matches(','))
      break;
    CHECK_PARSER_STATUS(ParseOptionalQuotableIdentifier(ident, present));
  }
  return Status::OK();
}

// Reads an optional identifier list enclosed in `open` ... `close`.
// `idlist` is cleared first and stays empty when the opening delimiter is absent.
Status OnnxParser::Parse(char open, IdList& idlist, char close) {
  idlist.Clear();
  if (!Matches(open))
    return Status::OK();
  PARSE(idlist);
  MATCH(close);
  return Status::OK();
}

// Reads a comma-separated list whose elements are either a plain identifier or an attribute
// definition. An identifier followed by ':' or '=' starts an attribute, which is appended to
// `attrlist`; any other identifier is appended to `idlist`. Both lists are cleared first.
// Returns the first parse error encountered, including errors from parsing an attribute.
Status OnnxParser::Parse(IdList& idlist, AttrList& attrlist) {
  idlist.Clear();
  attrlist.Clear();
  do {
    std::string id;
    CHECK_PARSER_STATUS(ParseQuotableIdentifier(id));
    auto next = NextChar();
    if (next == ':' || next == '=') {
      // Propagate attribute parse failures instead of discarding the returned status.
      CHECK_PARSER_STATUS(Parse(*attrlist.Add(), id));
    } else {
      *idlist.Add() = id;
    }
  } while (Matches(','));
  return Status::OK();
}

// Reads an optional `open` ... `close` section containing identifiers and attribute definitions.
// When the opening delimiter is absent, both output lists are simply cleared.
Status OnnxParser::Parse(char open, IdList& idlist, AttrList& attrlist, char close) {
  if (!Matches(open)) {
    idlist.Clear();
    attrlist.Clear();
    return Status::OK();
  }
  PARSE(idlist, attrlist);
  MATCH(close);
  return Status::OK();
}

// Reads a comma-separated list of dimensions into `shape` (existing dims are removed first).
// Each dimension is one of:
//   ?            an unknown dimension (added with neither a param nor a value),
//   identifier   a symbolic dimension,
//   integer      a fixed dimension value.
Status OnnxParser::Parse(TensorShapeProto& shape) {
  shape.clear_dim();
  do {
    if (Matches('?')) {
      shape.add_dim();
      continue;
    }
    // Try a symbolic identifier first, then fall back to an integer value.
    const auto symbol = ParseOptionalIdentifier();
    if (symbol.empty()) {
      int64_t extent = 0;
      PARSE_TOKEN(extent);
      shape.add_dim()->set_dim_value(extent);
    } else {
      shape.add_dim()->set_dim_param(symbol);
    }
  } while (Matches(','));
  return Status::OK();
}

// Parses a type expression into `typeProto`.
//
// Grammar:
//   prim-type                       scalar tensor (rank 0, shape with zero dimensions)
//   prim-type []                    tensor of unknown rank (no shape recorded)
//   prim-type [dim, ...]            tensor of known rank > 0
//   seq ( type )
//   map ( prim-type , type )
//   optional ( type )
//   sparse_tensor ( tensor-type )
Status OnnxParser::Parse(TypeProto& typeProto) {
  // Shared by the dense and sparse tensor cases: records the element type and then reads the
  // optional bracketed shape that follows it.
  auto parse_elem_and_shape = [this](auto* tensor_like, int elem_type) -> Status {
    tensor_like->set_elem_type(elem_type);
    tensor_like->clear_shape();
    if (!this->Matches('[')) {
      // Create shape with zero dimensions for scalar
      (void)(tensor_like->mutable_shape());
      return Status::OK();
    }
    // "[]" leaves the shape absent, which denotes unknown rank.
    if (!this->Matches(']')) {
      CHECK_PARSER_STATUS(this->Parse(*tensor_like->mutable_shape()));
      CHECK_PARSER_STATUS(this->Match(']'));
    }
    return Status::OK();
  };

  std::string id;
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  int dtype = PrimitiveTypeNameMap::Lookup(id);
  if (dtype != 0)
    return parse_elem_and_shape(typeProto.mutable_tensor_type(), dtype);

  switch (KeyWordMap::Lookup(id)) {
    case KeyWordMap::KeyWord::SEQ_TYPE: {
      MATCH('(');
      PARSE(*typeProto.mutable_sequence_type()->mutable_elem_type());
      MATCH(')');
      break;
    }
    case KeyWordMap::KeyWord::MAP_TYPE: {
      MATCH('(');
      auto* map_type = typeProto.mutable_map_type();
      CHECK_PARSER_STATUS(ParseIdentifier(id));
      dtype = PrimitiveTypeNameMap::Lookup(id);
      if (dtype == 0) {
        return ParseError("Expecting primitive type as map key type.");
      }
      map_type->set_key_type(dtype);
      MATCH(',');
      PARSE(*map_type->mutable_value_type());
      MATCH(')');
      break;
    }
    case KeyWordMap::KeyWord::OPTIONAL_TYPE: {
      MATCH('(');
      PARSE(*typeProto.mutable_optional_type()->mutable_elem_type());
      MATCH(')');
      break;
    }
    case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE: {
      MATCH('(');
      CHECK_PARSER_STATUS(ParseIdentifier(id));
      dtype = PrimitiveTypeNameMap::Lookup(id);
      if (dtype == 0) {
        return ParseError("Unexpected type in sparse-tensor element type.");
      }
      CHECK_PARSER_STATUS(parse_elem_and_shape(typeProto.mutable_sparse_tensor_type(), dtype));
      MATCH(')');
      break;
    }
    default:
      return ParseError("Unexpected type.");
  }
  return Status::OK();
}

// Reads a value-info written as "[type] name"; the type part is optional.
Status OnnxParser::Parse(ValueInfoProto& valueinfo) {
  if (NextIsType()) {
    PARSE(*valueinfo.mutable_type());
  }
  std::string parsed_name;
  CHECK_PARSER_STATUS(ParseQuotableIdentifier(parsed_name));
  valueinfo.set_name(parsed_name);
  return Status::OK();
}

// Reads a mandatory `open` ... `close` section holding zero or more comma-separated value-infos,
// appending each one to `vilist` (which is not cleared here).
Status OnnxParser::Parse(char open, ValueInfoList& vilist, char close) {
  MATCH(open);
  if (Matches(close))
    return Status::OK();
  do {
    PARSE(*vilist.Add());
  } while (Matches(','));
  MATCH(close);
  return Status::OK();
}

// Replaces the contents of `vilist` with the parenthesized value-info list that comes next.
Status OnnxParser::ParseGraphInputOutput(ValueInfoList& vilist) {
  vilist.Clear();
  return Parse('(', vilist, ')');
}

// Reads a parenthesized list of function inputs or outputs.
// Each element is "Name" or "Type Name". Every name is appended to `idlist` (cleared first);
// for typed elements a value-info carrying the type and name is also appended to `vilist`.
// `vilist` is deliberately not cleared, because it accumulates entries across inputs and outputs.
Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist) {
  idlist.Clear();
  MATCH('(');
  if (Matches(')'))
    return Status::OK();
  do {
    std::string* name_slot = idlist.Add();
    // A value-info entry is created only when a type annotation is present.
    ValueInfoProto* typed_entry = NextIsType() ? vilist.Add() : nullptr;
    if (typed_entry != nullptr) {
      PARSE(*(typed_entry->mutable_type()));
    }
    CHECK_PARSER_STATUS(ParseQuotableIdentifier(*name_slot));
    if (typed_entry != nullptr) {
      typed_entry->set_name(*name_slot);
    }
  } while (Matches(','));
  MATCH(')');
  return Status::OK();
}

// Reads the optional parenthesized graph-input list.
// Every element is a value-info, optionally followed by "= initial-value". The value-info always
// goes to `inputs` (cleared first); when an initial value is present, a tensor with the same name
// is additionally appended to `initializers` (not cleared here).
Status OnnxParser::ParseInput(ValueInfoList& inputs, TensorList& initializers) {
  inputs.Clear();
  if (!Matches('('))
    return Status::OK();
  if (Matches(')'))
    return Status::OK();
  do {
    ValueInfoProto input_info;
    PARSE(input_info);
    *inputs.Add() = input_info;
    if (!Matches('='))
      continue;
    // default value for input
    TensorProto& default_value = *initializers.Add();
    default_value.set_name(input_info.name());
    CHECK_PARSER_STATUS(Parse(default_value, input_info.type()));
  } while (Matches(','));
  MATCH(')');
  return Status::OK();
}

// Reads the optional "<" ... ">" section of value-infos and initializers.
// Unlike graph inputs, each element is EITHER an initializer (when followed by "= value"), which is
// appended to `initializers`, OR a plain value-info, which is appended to `value_infos`.
// `value_infos` is cleared first; `initializers` is not.
Status OnnxParser::ParseValueInfo(ValueInfoList& value_infos, TensorList& initializers) {
  value_infos.Clear();
  if (!Matches('<'))
    return Status::OK();
  if (Matches('>'))
    return Status::OK();
  do {
    ValueInfoProto entry;
    PARSE(entry);
    if (!Matches('=')) {
      // valueinfo
      *value_infos.Add() = entry;
      continue;
    }
    // initializer
    TensorProto& init = *initializers.Add();
    init.set_name(entry.name());
    CHECK_PARSER_STATUS(Parse(init, entry.type()));
  } while (Matches(','));
  MATCH('>');
  return Status::OK();
}

// Reads one or more comma-separated "key : value" pairs of string tokens and appends each pair
// to `stringStringList`.
Status OnnxParser::Parse(StringStringList& stringStringList) {
  std::string text;
  do {
    auto* entry = stringStringList.Add();
    PARSE_TOKEN(text);
    entry->set_key(text);
    MATCH(':');
    PARSE_TOKEN(text);
    entry->set_value(text);
  } while (Matches(','));
  return Status::OK();
}

// Reads a complete tensor: "type [name] [=] data". `tensorProto` is reset before parsing.
Status OnnxParser::Parse(TensorProto& tensorProto) {
  tensorProto = TensorProto();
  // The leading type must be a concrete tensor type with numeric dimensions.
  TypeProto declared_type;
  PARSE(declared_type);
  tensorProto.set_name(ParseOptionalIdentifier());
  // The '=' is optional so that initializers and tensors in other contexts share this code path.
  (void)Matches('=');
  return Parse(tensorProto, declared_type);
}

// Parses the data of a TensorProto whose type has already been parsed.
//
// `tensorTypeProto` must be a tensor type with a shape in which every dimension is numeric;
// its element type and dimensions are copied into `tensorProto`. The data that follows is one of:
//   { v1, v2, ... }   inline values, stored in the field that matches the element type;
//   [ "k": "v", ... ] external-data key/value pairs (data_location is set to EXTERNAL);
//   nothing           no data is recorded.
// Returns a parse error for a non-tensor type, a missing shape, a non-numeric dimension,
// an element type with no inline representation here, or a value that does not fit in int32
// for the element types stored in int32_data.
Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto) {
  if (!tensorTypeProto.has_tensor_type())
    return ParseError("Error parsing TensorProto (expected a tensor type).");
  auto elem_type = tensorTypeProto.tensor_type().elem_type();
  tensorProto.set_data_type(elem_type);
  if (!tensorTypeProto.tensor_type().has_shape())
    return ParseError("Error parsing TensorProto (expected a tensor shape).");
  for (auto& dim : tensorTypeProto.tensor_type().shape().dim()) {
    if (!dim.has_dim_value())
      return ParseError("Error parsing TensorProto shape (expected numeric dimension).");
    auto dimval = dim.dim_value();
    tensorProto.add_dims(dimval);
  }

  // Parse the actual values:

  int64_t intval = 0;
  uint64_t uintval = 0;
  float floatval = 0.0;
  double dblval = 0.0;
  std::string strval;
  if (Matches('{')) {
    if (!Matches('}')) {
      do {
        switch (static_cast<TensorProto::DataType>(elem_type)) {
          // Element types whose values are carried in int32_data.
          case TensorProto::DataType::TensorProto_DataType_INT2:
          case TensorProto::DataType::TensorProto_DataType_INT4:
          case TensorProto::DataType::TensorProto_DataType_INT8:
          case TensorProto::DataType::TensorProto_DataType_INT16:
          case TensorProto::DataType::TensorProto_DataType_INT32:
          case TensorProto::DataType::TensorProto_DataType_UINT2:
          case TensorProto::DataType::TensorProto_DataType_UINT4:
          case TensorProto::DataType::TensorProto_DataType_UINT8:
          case TensorProto::DataType::TensorProto_DataType_UINT16:
          case TensorProto::DataType::TensorProto_DataType_FLOAT16:
          case TensorProto::DataType::TensorProto_DataType_BFLOAT16:
          case TensorProto::DataType::TensorProto_DataType_FLOAT8E4M3FN:
          case TensorProto::DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
          case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2:
          case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
          case TensorProto::DataType::TensorProto_DataType_FLOAT8E8M0:
          case TensorProto::DataType::TensorProto_DataType_BOOL:
          case TensorProto::DataType::TensorProto_DataType_FLOAT4E2M1:
            PARSE_TOKEN(intval);
            if (intval > std::numeric_limits<int32_t>::max() || intval < std::numeric_limits<int32_t>::min()) {
              // ParseError concatenates its arguments; it does not interpret printf-style formats.
              return ParseError("Mismatch between data type and value: ", elem_type, ", ", intval);
            }
            // NOLINTNEXTLINE(bugprone-narrowing-conversions)
            tensorProto.add_int32_data(intval);
            break;
          case TensorProto::DataType::TensorProto_DataType_INT64:
            PARSE_TOKEN(intval);
            tensorProto.add_int64_data(intval);
            break;
          case TensorProto::DataType::TensorProto_DataType_UINT32:
          case TensorProto::DataType::TensorProto_DataType_UINT64:
            PARSE_TOKEN(uintval);
            tensorProto.add_uint64_data(uintval);
            break;
          case TensorProto::DataType::TensorProto_DataType_COMPLEX64:
          case TensorProto::DataType::TensorProto_DataType_FLOAT:
            PARSE_TOKEN(floatval);
            tensorProto.add_float_data(floatval);
            break;
          case TensorProto::DataType::TensorProto_DataType_COMPLEX128:
          case TensorProto::DataType::TensorProto_DataType_DOUBLE:
            PARSE_TOKEN(dblval);
            tensorProto.add_double_data(dblval);
            break;
          case TensorProto::DataType::TensorProto_DataType_STRING:
            PARSE_TOKEN(strval);
            tensorProto.add_string_data(strval);
            break;
          default:
            return ParseError("Unhandled type: ", elem_type);
        }
      } while (Matches(','));
      MATCH('}');
    }
  } else if (Matches('[')) {
    tensorProto.set_data_location(TensorProto::DataLocation::TensorProto_DataLocation_EXTERNAL);
    auto& externalData = *tensorProto.mutable_external_data();
    PARSE(externalData);
    MATCH(']');
  }
  return Status::OK();
}

// True when PeekIdentifier() yields a non-empty identifier at the current position.
bool OnnxParser::NextIsIdentifier() {
  return !PeekIdentifier().empty();
}

// True when the peeked identifier is a primitive type name or one of the composite type keywords
// (sequence, map, optional, sparse tensor).
bool OnnxParser::NextIsType() {
  const auto peeked = PeekIdentifier();
  if (PrimitiveTypeNameMap::IsTypeName(peeked))
    return true;
  const auto keyword = KeyWordMap::Lookup(peeked);
  return keyword == KeyWordMap::KeyWord::SEQ_TYPE || keyword == KeyWordMap::KeyWord::MAP_TYPE ||
      keyword == KeyWordMap::KeyWord::OPTIONAL_TYPE || keyword == KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE;
}

// Parses one (non-list) attribute value into `attr` and sets its type accordingly.
//
// The value is one of:
//   - a type, optionally followed by tensor data ('{', '=' or a tensor name): TENSOR, else TYPE_PROTO;
//   - a special float spelling accepted by NextIsValidFloatString(): FLOAT;
//   - any other identifier-led text: parsed as a GRAPH;
//   - '@' name: a reference to another attribute (only ref_attr_name is set);
//   - an integer, float or string literal: INT, FLOAT or STRING.
//
// `expected` is the type from a type annotation, or UNDEFINED when there is none. When it differs
// from the parsed type, an INT value is implicitly converted if FLOAT was expected; any other
// mismatch is a parse error.
//
// NOTE(review): std::stoll / std::stof throw on out-of-range literals; those exceptions are not
// translated into a ParseError here.
Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr, AttributeProto_AttributeType expected) {
  // Parse a single-value
  auto next = NextChar();
  // std::isalpha is undefined for negative char values, hence the cast to unsigned char.
  if (std::isalpha(static_cast<unsigned char>(next)) || next == '_') {
    if (NextIsType()) {
      TypeProto typeProto;
      CHECK_PARSER_STATUS(Parse(typeProto));
      next = NextChar();
      if ((next == '{') || (next == '=') || (NextIsIdentifier())) {
        attr.set_type(AttributeProto_AttributeType_TENSOR);
        auto& tensorProto = *attr.mutable_t();
        CHECK_PARSER_STATUS(ParseOptionalQuotableIdentifier(*tensorProto.mutable_name()));
        (void)Matches('='); // Optional, to unify handling of initializers
        CHECK_PARSER_STATUS(Parse(tensorProto, typeProto));
      } else {
        attr.set_type(AttributeProto_AttributeType_TYPE_PROTO);
        attr.mutable_tp()->CopyFrom(typeProto);
      }
    } else {
      if (NextIsValidFloatString()) {
        Literal literal;
        PARSE_TOKEN(literal);
        attr.set_type(AttributeProto_AttributeType_FLOAT);
        attr.set_f(std::stof(literal.value));
      } else {
        attr.set_type(AttributeProto_AttributeType_GRAPH);
        PARSE(*attr.mutable_g());
      }
    }
  } else if (Matches('@')) {
    std::string name;
    CHECK_PARSER_STATUS(ParseQuotableIdentifier(name));
    attr.set_ref_attr_name(name);
  } else {
    Literal literal;
    PARSE_TOKEN(literal);
    switch (literal.type) {
      case LiteralType::UNDEFINED:
        return ParseError("Internal error");
      case LiteralType::INT_LITERAL:
        attr.set_type(AttributeProto_AttributeType_INT);
        // The attribute stores a 64-bit integer; std::stol would be limited to 32 bits on
        // platforms where long is 32 bits wide.
        attr.set_i(std::stoll(literal.value));
        break;
      case LiteralType::FLOAT_LITERAL:
        attr.set_type(AttributeProto_AttributeType_FLOAT);
        attr.set_f(std::stof(literal.value));
        break;
      case LiteralType::STRING_LITERAL:
        attr.set_type(AttributeProto_AttributeType_STRING);
        attr.set_s(literal.value);
        break;
    }
  }
  if ((expected != AttributeProto_AttributeType_UNDEFINED) && (expected != attr.type())) {
    // Mismatch between type-annotation and attribute-value. We do an implicit cast
    // only in the special case of FLOAT type and integral value like 2
    if ((expected == AttributeProto_AttributeType_FLOAT) && (attr.type() == AttributeProto_AttributeType_INT)) {
      attr.set_type(AttributeProto_AttributeType_FLOAT);
      attr.set_f(static_cast<float>(attr.i()));
    } else {
      return ParseError(
          "Mismatch between expected type ",
          AttributeProto_AttributeType_Name(expected),
          " and specified value's type ",
          AttributeProto_AttributeType_Name(attr.type()));
    }
  }
  return Status::OK();
}

// Reads a complete attribute definition: the attribute name followed by its optional type
// annotation and its value. `attr` is cleared first.
Status OnnxParser::Parse(AttributeProto& attr) {
  attr.Clear();
  std::string attr_name;
  CHECK_PARSER_STATUS(ParseIdentifier(attr_name));
  return Parse(attr, attr_name);
}

// True for the attribute types that hold exactly one value (as opposed to a list of values).
static bool IsSingletonAttribute(AttributeProto_AttributeType type) {
  return type == AttributeProto_AttributeType_FLOAT || type == AttributeProto_AttributeType_INT ||
      type == AttributeProto_AttributeType_STRING || type == AttributeProto_AttributeType_TENSOR ||
      type == AttributeProto_AttributeType_GRAPH || type == AttributeProto_AttributeType_SPARSE_TENSOR ||
      type == AttributeProto_AttributeType_TYPE_PROTO;
}

// Maps a list attribute type to the type of its elements (e.g. INTS -> INT).
// Any other type is returned unchanged.
static AttributeProto_AttributeType ToSingletonType(AttributeProto_AttributeType type) {
  if (type == AttributeProto_AttributeType_FLOATS)
    return AttributeProto_AttributeType_FLOAT;
  if (type == AttributeProto_AttributeType_INTS)
    return AttributeProto_AttributeType_INT;
  if (type == AttributeProto_AttributeType_STRINGS)
    return AttributeProto_AttributeType_STRING;
  if (type == AttributeProto_AttributeType_TENSORS)
    return AttributeProto_AttributeType_TENSOR;
  if (type == AttributeProto_AttributeType_GRAPHS)
    return AttributeProto_AttributeType_GRAPH;
  if (type == AttributeProto_AttributeType_SPARSE_TENSORS)
    return AttributeProto_AttributeType_SPARSE_TENSOR;
  if (type == AttributeProto_AttributeType_TYPE_PROTOS)
    return AttributeProto_AttributeType_TYPE_PROTO;
  return type;
}

// Parses the remainder of an attribute definition whose name has already been read:
//   [ ':' type ] '=' ( value | '[' value, ... ']' )
//
// `name` supplies the attribute name. Note that it is also reused as scratch storage: when a type
// annotation is present, it is overwritten with the annotated type's name.
//
// For a list, each element is parsed as a single value and appended to the list field matching
// its type (ints, floats, strings, tensors, graphs or type_protos); the attribute type is set to
// the corresponding list type. Elements whose type remains UNDEFINED (e.g. '@' references) are
// not stored. An empty list requires a type annotation naming a list type.
Status OnnxParser::Parse(AttributeProto& attr, std::string& name) {
  attr.set_name(name);
  if (Matches(':')) {
    CHECK_PARSER_STATUS(ParseIdentifier(name));
    int attrtype = AttributeTypeNameMap::Lookup(name);
    if (attrtype != 0) {
      attr.set_type(static_cast<AttributeProto_AttributeType>(attrtype));
    } else {
      return ParseError("Unexpected attribute type.");
    }
  }
  MATCH('=');
  if (NextChar() == '[') {
    // Parse a list of values. For an empty list, the type MUST be specified
    // using the type-annotation syntax of ": type".
    MATCH('[');
    if (NextChar() != ']') {
      do {
        AttributeProto nextval;
        // After the first element, attr.type() is a list type, so later elements are checked
        // against the element type established so far.
        auto expected_type = ToSingletonType(attr.type());
        CHECK_PARSER_STATUS(ParseSingleAttributeValue(nextval, expected_type));
        switch (nextval.type()) {
          case AttributeProto_AttributeType_INT:
            attr.set_type(AttributeProto_AttributeType_INTS);
            attr.add_ints(nextval.i());
            break;
          case AttributeProto_AttributeType_FLOAT:
            attr.set_type(AttributeProto_AttributeType_FLOATS);
            attr.add_floats(nextval.f());
            break;
          case AttributeProto_AttributeType_STRING:
            attr.add_strings(nextval.s());
            attr.set_type(AttributeProto_AttributeType_STRINGS);
            break;
          // Message-valued elements were previously parsed and then silently discarded.
          case AttributeProto_AttributeType_TENSOR:
            attr.set_type(AttributeProto_AttributeType_TENSORS);
            *attr.add_tensors() = nextval.t();
            break;
          case AttributeProto_AttributeType_GRAPH:
            attr.set_type(AttributeProto_AttributeType_GRAPHS);
            *attr.add_graphs() = nextval.g();
            break;
          case AttributeProto_AttributeType_TYPE_PROTO:
            attr.set_type(AttributeProto_AttributeType_TYPE_PROTOS);
            *attr.add_type_protos() = nextval.tp();
            break;
          default:
            break;
        }
      } while (Matches(','));
    } else {
      if (attr.type() == AttributeProto_AttributeType_UNDEFINED)
        return ParseError("Empty list attribute value requires type annotation.");
      if (IsSingletonAttribute(attr.type()))
        return ParseError("Singleton attribute value cannot be specified as a list.");
    }
    MATCH(']');
  } else {
    CHECK_PARSER_STATUS(ParseSingleAttributeValue(attr, attr.type()));
  }
  return Status::OK();
}

// Reads an optional "<" attr, attr, ... ">" section into `attrlist` (cleared first).
Status OnnxParser::Parse(AttrList& attrlist) {
  attrlist.Clear();
  if (!Matches('<'))
    return Status::OK();
  do {
    PARSE(*attrlist.Add());
  } while (Matches(','));
  MATCH('>');
  return Status::OK();
}

// Reads one node:
//   [ '[' name ']' ] outputs '=' [domain '.'] op_type [ ':' overload ] [attrs] '(' inputs ')' [attrs]
// Attributes may be written before or after the input list; the trailing position is only tried
// when no attributes were found in the leading position.
Status OnnxParser::Parse(NodeProto& node) {
  // Optional bracketed node name.
  if (Matches('[')) {
    CHECK_PARSER_STATUS(ParseOptionalQuotableIdentifier(*node.mutable_name()));
    MATCH(']');
  }
  PARSE(*node.mutable_output());
  MATCH('=');

  // A dotted name "a.b.Op" is split into domain "a.b" and op type "Op".
  std::string domain_prefix;
  std::string last_part = ParseOptionalIdentifier();
  while (Matches('.')) {
    if (!domain_prefix.empty())
      domain_prefix.push_back('.');
    domain_prefix.append(last_part);
    CHECK_PARSER_STATUS(ParseIdentifier(last_part));
  }
  node.set_domain(domain_prefix);
  node.set_op_type(last_part);

  if (Matches(':')) {
    std::string overload_name;
    CHECK_PARSER_STATUS(ParseIdentifier(overload_name));
    node.set_overload(overload_name);
  }

  PARSE(*node.mutable_attribute());
  MATCH('(');
  PARSE(*node.mutable_input());
  MATCH(')');

  // Permit attributes to be specified before or after parameters.
  const bool no_leading_attributes = (node.attribute_size() == 0);
  if (no_leading_attributes) {
    PARSE(*node.mutable_attribute());
  }
  return Status::OK();
}

// Reads a brace-enclosed sequence of nodes into `nodelist` (cleared first).
Status OnnxParser::Parse(NodeList& nodelist) {
  nodelist.Clear();
  MATCH('{');
  for (;;) {
    if (Matches('}'))
      break;
    PARSE(*nodelist.Add());
  }
  return Status::OK();
}

// Reads a graph: its (quotable) name followed by the graph definition.
Status OnnxParser::Parse(GraphProto& graph) {
  std::string graph_name;
  CHECK_PARSER_STATUS(ParseQuotableIdentifier(graph_name));
  return Parse(graph_name, graph);
}

// Reads the definition of a graph whose name has already been parsed:
//   inputs '=>' outputs [value-infos / initializers] '{' nodes '}'
Status OnnxParser::Parse(std::string name, GraphProto& graph) {
  graph.set_name(name);

  // Initializers can come from both the input list and the value-info section, so the list is
  // cleared once up front and then appended to by both steps.
  auto* initializers = graph.mutable_initializer();
  initializers->Clear();
  CHECK_PARSER_STATUS(ParseInput(*graph.mutable_input(), *initializers));

  // The "=>" separator.
  MATCH('=');
  MATCH('>', false);

  CHECK_PARSER_STATUS(ParseGraphInputOutput(*graph.mutable_output()));
  CHECK_PARSER_STATUS(ParseValueInfo(*graph.mutable_value_info(), *graph.mutable_initializer()));
  return Parse(*graph.mutable_node());
}

// Reads a function definition into `fn` (cleared first):
//   [ '<' keyword ':' value, ... '>' ] name [ '<' attrs '>' ] inputs '=>' outputs [ '<' value-infos '>' ] '{' nodes '}'
// Supported header keywords: opset import, doc string, domain and overload.
Status OnnxParser::Parse(FunctionProto& fn) {
  fn.Clear();
  std::string text;

  // Optional header with function-level properties.
  if (Matches('<')) {
    do {
      KeyWordMap::KeyWord property = KeyWordMap::KeyWord::NONE;
      PARSE_TOKEN(property);
      MATCH(':');
      switch (property) {
        case KeyWordMap::KeyWord::OPSET_IMPORT:
          PARSE(*fn.mutable_opset_import());
          break;
        case KeyWordMap::KeyWord::DOC_STRING:
          PARSE_TOKEN(text);
          fn.set_doc_string(text);
          break;
        case KeyWordMap::KeyWord::DOMAIN_KW:
          PARSE_TOKEN(text);
          fn.set_domain(text);
          break;
        case KeyWordMap::KeyWord::OVERLOAD_KW:
          PARSE_TOKEN(text);
          fn.set_overload(text);
          break;
        default:
          return ParseError("Unhandled keyword.");
      }
    } while (Matches(','));
    MATCH('>');
  }

  std::string function_name;
  CHECK_PARSER_STATUS(ParseQuotableIdentifier(function_name));
  fn.set_name(function_name);

  // Attribute names and attributes with default values.
  PARSE('<', *fn.mutable_attribute(), *fn.mutable_attribute_proto(), '>');

  // Typed inputs and outputs both contribute to value_info, so it is cleared only once here.
  auto* value_infos = fn.mutable_value_info();
  value_infos->Clear();
  CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_input(), *value_infos));
  MATCH('=');
  MATCH('>', false);
  CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_output(), *fn.mutable_value_info()));

  // Optional additional value-infos.
  if (NextChar() == '<') {
    PARSE('<', *fn.mutable_value_info(), '>');
  }
  return Parse(*fn.mutable_node());
}

// Reads a bracketed list of opset imports, each written as "domain" : version, and appends
// them to `opsets` (which is not cleared here).
Status OnnxParser::Parse(OpsetIdList& opsets) {
  std::string domain_text;
  int64_t version_number = 0;
  MATCH('[');
  if (Matches(']'))
    return Status::OK();
  do {
    auto* entry = opsets.Add();
    PARSE_TOKEN(domain_text);
    entry->set_domain(domain_text);
    MATCH(':');
    PARSE_TOKEN(version_number);
    entry->set_version(version_number);
  } while (Matches(','));
  MATCH(']');
  return Status::OK();
}

// Reads a whole model into `model` (cleared first):
//   [ '<' keyword ':' value, ... '>' ] graph function*
// Functions are read until the end of the input is reached.
Status OnnxParser::Parse(ModelProto& model) {
  model.Clear();
  std::string text;
  int64_t number = 0;

  // Optional header with model-level properties.
  if (Matches('<')) {
    do {
      KeyWordMap::KeyWord property = KeyWordMap::KeyWord::NONE;
      PARSE_TOKEN(property);
      MATCH(':');
      switch (property) {
        case KeyWordMap::KeyWord::IR_VERSION:
          PARSE_TOKEN(number);
          model.set_ir_version(number);
          break;
        case KeyWordMap::KeyWord::OPSET_IMPORT:
          PARSE(*model.mutable_opset_import());
          break;
        case KeyWordMap::KeyWord::PRODUCER_NAME:
          PARSE_TOKEN(text);
          model.set_producer_name(text);
          break;
        case KeyWordMap::KeyWord::PRODUCER_VERSION:
          PARSE_TOKEN(text);
          model.set_producer_version(text);
          break;
        case KeyWordMap::KeyWord::DOMAIN_KW:
          PARSE_TOKEN(text);
          model.set_domain(text);
          break;
        case KeyWordMap::KeyWord::MODEL_VERSION:
          PARSE_TOKEN(number);
          model.set_model_version(number);
          break;
        case KeyWordMap::KeyWord::DOC_STRING:
          PARSE_TOKEN(text);
          model.set_doc_string(text);
          break;
        case KeyWordMap::KeyWord::METADATA_PROPS: {
          // Bracketed list of "key": "value" pairs; "[]" is accepted as an empty list.
          auto& props = *model.mutable_metadata_props();
          MATCH('[');
          if (!Matches(']')) {
            PARSE(props);
            MATCH(']');
          }
          break;
        }
        default:
          return ParseError("Unhandled keyword.");
      }
    } while (Matches(','));
    MATCH('>');
  }

  PARSE(*model.mutable_graph());

  // Everything after the main graph is a sequence of function definitions.
  auto* model_functions = model.mutable_functions();
  while (!EndOfInput()) {
    PARSE(*model_functions->Add());
  }
  return Status::OK();
}
const std::unordered_map<std::string, KeyWordMap::KeyWord>& KeyWordMap::Instance() {
  static KeyWordMap instance;
  return instance.map_;
}

// Reverse lookup: returns the text of an entry in the keyword table that maps to `kw`,
// or "undefined" when there is no such entry.
const std::string& KeyWordMap::ToString(KeyWord kw) {
  static std::string undefined("undefined");
  const auto& table = Instance();
  for (auto it = table.begin(); it != table.end(); ++it) {
    if (it->second == kw)
      return it->first;
  }
  return undefined;
}
} // namespace ONNX_NAMESPACE
