// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "onnx/defs/schema.h"
#include "onnx/defs/tensor_proto_util.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
namespace shape_inference {

using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>;

// We reuse TensorShapeProto to propagate statically known (partial) information about
// the values of tensors. It is intended for tensors used to store shape information
// (the return values of ops like Shape and input values of ops like Reshape/Expand).

// A DataValueMap is used to store the statically known (partial) values of variables.
using DataValueMap = std::unordered_map<std::string, TensorShapeProto>;

class SymbolTableImpl : public SymbolTable {
 public:
  SymbolTableImpl() = default;

  void addFromGraph(const GraphProto& g) override {
    AddExistingSymbolicDims(g.input());
    AddExistingSymbolicDims(g.output());
    AddExistingSymbolicDims(g.value_info());
  }
  // Creates a new unique symbol with the given prefix and adds it to the SymbolTable
  // Returns the newly created symbol
  std::string createNew(const std::string& symbol_prefix) override {
    std::string newSymbol;
    do {
      newSymbol = symbol_prefix + std::to_string(index_++);
    } while (existing_symbols.count(newSymbol) > 0);
    existing_symbols.insert(newSymbol);
    return newSymbol;
  }

 private:
  unsigned int index_{0};
  std::unordered_set<std::string> existing_symbols;

  // TypeProto_Tensor or TypeProto_SparseTensor
  template <typename TensorTypeProto>
  void AddExistingSymbolicDims(const TensorTypeProto& tensorType) {
    if (tensorType.has_shape()) {
      for (int i = 0; i < tensorType.shape().dim_size(); ++i) {
        if (tensorType.shape().dim(i).has_dim_param()) {
          existing_symbols.insert(tensorType.shape().dim(i).dim_param());
        }
      }
    }
  }

  void AddExistingSymbolicDims(const TypeProto& typeProto) {
    const auto val_case = typeProto.value_case();
    switch (val_case) {
      case TypeProto::kTensorType:
        AddExistingSymbolicDims(typeProto.tensor_type());
        break;
      case TypeProto::kSparseTensorType:
        AddExistingSymbolicDims(typeProto.sparse_tensor_type());
        break;
      case TypeProto::kSequenceType:
        AddExistingSymbolicDims(typeProto.sequence_type().elem_type());
        break;
      case TypeProto::kOptionalType:
        AddExistingSymbolicDims(typeProto.optional_type().elem_type());
        break;
      case TypeProto::kMapType:
        AddExistingSymbolicDims(typeProto.map_type().value_type());
        break;
      default:
        break;
    }
  }

  void AddExistingSymbolicDims(const google::protobuf::RepeatedPtrField<ValueInfoProto>& protos) {
    for (const auto& proto : protos) {
      AddExistingSymbolicDims(proto.type());
    }
  }
};

struct GraphInferenceContext {
  GraphInferenceContext(
      const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
      std::unordered_map<std::string, int> opset_imports_in,
      SymbolTable* symbol_table_in = nullptr,
      const ModelLocalFunctionsMap& model_local_functions_in = {},
      const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
      DataValueMap* generated_shape_data_by_name_in = nullptr,
      const int ir_version_in = IR_VERSION)
      : outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in},
        opset_imports{std::move(opset_imports_in)},
        symbol_table{symbol_table_in},
        model_local_functions{model_local_functions_in},
        schema_registry{schema_registry_in},
        generated_shape_data_by_name{generated_shape_data_by_name_in},
        ir_version{ir_version_in} {}

  const std::unordered_map<std::string, TypeProto*>* outer_scope_value_types_by_name;
  const std::unordered_map<std::string, int> opset_imports;
  SymbolTable* symbol_table;
  const ModelLocalFunctionsMap& model_local_functions;
  const ISchemaRegistry* schema_registry;
  DataValueMap* generated_shape_data_by_name;
  const int ir_version;
};

class GraphInferencerImpl : public GraphInferencer {
 public:
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context) : g_{&g}, context_{&context} {}
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context, const ShapeInferenceOptions& options)
      : g_{&g}, context_{&context}, options_(options) {}

  std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) override;

 private:
  GraphProto* g_;
  GraphInferenceContext* context_;
  ShapeInferenceOptions options_;
};

struct InferenceContextImpl : public InferenceContext {
  InferenceContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      const std::unordered_map<std::string, const SparseTensorProto*>& inputSparseDataByName,
      const ShapeInferenceOptions& options,
      DataValueMap* generatedShapeData = nullptr,
      GraphInferenceContext* graphInferenceContext = nullptr)
      : graphInferenceContext_{graphInferenceContext}, options_(options), node_(&n) {
    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
      if (attr.has_g()) {
        // need a mutable GraphProto to run inferencing on this attribute
        graphProtoAttributesByName_[attr.name()] = attr.mutable_g();
      }
    }

    for (const auto& input : n.input()) {
      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      // input data can be in 1 of the 3 containers
      // inputDataByName - this is when input is TensorProto
      // inputSparseDataByName - this is when input is SparseTensorProto
      // generatedShapeData - this is when input was generated as part of partial data propagation
      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
        allInputSparseData_.push_back(nullptr);
        allShapeInputData_.push_back(nullptr);
      } else {
        allInputData_.push_back(nullptr);
        const auto inputSparseDataIter = inputSparseDataByName.find(input);
        if (inputSparseDataIter != inputSparseDataByName.cend()) {
          allInputSparseData_.push_back(inputSparseDataIter->second);
          allShapeInputData_.push_back(nullptr);
        } else {
          allInputSparseData_.push_back(nullptr);
          if (generatedShapeData != nullptr) {
            const auto inputShapeDataIter = generatedShapeData->find(input);
            if (inputShapeDataIter != generatedShapeData->cend()) {
              allShapeInputData_.push_back(&inputShapeDataIter->second);
            } else {
              allShapeInputData_.push_back(nullptr);
            }
          } else {
            allShapeInputData_.push_back(nullptr);
          }
        }
      }
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  const TensorProto* getInputData(size_t index) const override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputData_[index];
  }

  const TensorShapeProto* getSymbolicInput(size_t index) const override {
    if (index >= allShapeInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }

    return allShapeInputData_[index];
  }

  const SparseTensorProto* getInputSparseData(size_t index) const override {
    if (index >= allInputSparseData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputSparseData_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  TypeProto* getOutputType(size_t index) override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  GraphInferencer* getGraphAttributeInferencer(const std::string& attr_name) override {
    if (!graphInferenceContext_) {
      fail_type_inference("GraphProto attribute inferencing is not enabled in this InferenceContextImpl instance.");
    }

    GraphInferencer* inferencer = nullptr;

    auto entry = graphAttributeInferencers_.find(attr_name);
    if (entry == graphAttributeInferencers_.cend()) {
      // create GraphInferencer instance
      auto attrNameToGraphProto = graphProtoAttributesByName_.find(attr_name);
      if (attrNameToGraphProto == graphProtoAttributesByName_.cend()) {
        fail_type_inference("Attribute ", attr_name, " does not contain a graph.");
      }

      auto new_inferencer =
          std::make_unique<GraphInferencerImpl>(*attrNameToGraphProto->second, *graphInferenceContext_, options_);
      inferencer = new_inferencer.get();
      graphAttributeInferencers_.emplace(attr_name, std::move(new_inferencer));
    } else {
      inferencer = entry->second.get();
    }

    return inferencer;
  }

  std::string getDisplayName() const override {
    if (node_ == nullptr)
      return "";
    if (node_->domain().empty()) {
      if (node_->name().empty())
        return MakeString("node ", node_->op_type());
      return MakeString("node ", node_->op_type(), " (", node_->name(), ")");
    }
    if (node_->name().empty())
      return MakeString("node ", node_->op_type(), "[", node_->domain(), "]");
    return MakeString("node ", node_->op_type(), "[", node_->domain(), "]", " (", node_->name(), ")");
  }

  std::vector<const TensorProto*> allInputData_;
  std::vector<const SparseTensorProto*> allInputSparseData_;
  std::vector<const TensorShapeProto*> allShapeInputData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
  std::unordered_map<std::string, GraphProto*> graphProtoAttributesByName_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  GraphInferenceContext* graphInferenceContext_;

  // mutable as internal cache of GraphInferencer instances
  mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_;
  ShapeInferenceOptions options_;
  NodeProto* node_;
};

struct DataPropagationContextImpl : public DataPropagationContext {
  DataPropagationContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      DataValueMap& generatedShapeData)
      : generatedShapeData_(generatedShapeData) {
    size_t input_idx = 0;

    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
    }

    for (const auto& input : n.input()) {
      inputIndexToNameMap_.insert({input_idx++, input});

      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
      } else {
        allInputData_.push_back(nullptr);
      }
    }

    size_t output_idx = 0;
    for (const auto& output : n.output()) {
      outputIndexToNameMap_.insert({output_idx++, output});
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  const TypeProto* getOutputType(size_t index) const override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  // Convert integer vector into TensorShapeProto
  template <typename INTEGER>
  void vectorToTensorShapeProto(const std::vector<INTEGER>& input_vals, TensorShapeProto& converted_tsp) const {
    for (unsigned int i = 0; i < input_vals.size(); ++i) {
      converted_tsp.mutable_dim()->Add()->set_dim_value(input_vals[i]);
    }
  }

  const TensorShapeProto* getInputData(size_t index) override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    const std::string input_name = inputIndexToNameMap_.at(index);
    // Gets it from previous data propagation
    auto iter = generatedShapeData_.find(input_name);
    if (iter != generatedShapeData_.end()) {
      return &iter->second;
    }
    // Otherwise, gets it from initializer if it exists
    const auto* input_data = allInputData_[index];
    // Only scalar (0D tensor) or 1D tensor can be converted for now
    // TODO: It should support tensors with more dimension on demand
    if (input_data != nullptr && (input_data->dims_size() == 0 || input_data->dims_size() == 1)) {
      TensorShapeProto tsp;

      if (input_data->data_type() == TensorProto_DataType_INT64) {
        vectorToTensorShapeProto(ParseData<int64_t>(input_data), tsp);
      } else if (input_data->data_type() == TensorProto_DataType_INT32) {
        vectorToTensorShapeProto(ParseData<int32_t>(input_data), tsp);
      } else {
        // Only supports integer type to form a shape
        return nullptr;
      }

      // Adds this TensorShapeProto from initializer into generatedShapeData
      // for future use
      auto result = generatedShapeData_.insert({input_name, std::move(tsp)});
      if (result.second) {
        return &(result.first->second);
      }
    }

    // If X has a known rank N, then X's value can be represented as an array
    // of N unknown values (represented as a TensorShapeProto).
    const TypeProto* type = getInputType(index);
    if ((type != nullptr) && (type->has_tensor_type())) {
      auto& tensor_type = type->tensor_type();
      if (tensor_type.has_shape()) {
        const TensorShapeProto& shape = tensor_type.shape();
        if ((shape.dim_size() == 1) && (shape.dim(0).has_dim_value())) {
          TensorShapeProto tsp;
          int64_t dim_value = shape.dim(0).dim_value();
          for (int64_t i = 0; i < dim_value; ++i) {
            tsp.add_dim();
          }
          auto result = generatedShapeData_.insert({input_name, std::move(tsp)});
          if (result.second) {
            return &(result.first->second);
          }
        }
      }
    }
    return nullptr;
  }

  void addOutputData(size_t index, TensorShapeProto&& tsp) override {
    if (index >= outputIndexToNameMap_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    auto result = generatedShapeData_.insert({outputIndexToNameMap_.at(index), std::move(tsp)});
    if (!result.second) {
      fail_shape_inference("Data for input  " + ONNX_NAMESPACE::to_string(index) + " already exists.");
    }
  }

  std::vector<const TensorProto*> allInputData_;
  std::unordered_map<size_t, std::string> inputIndexToNameMap_;
  std::unordered_map<size_t, std::string> outputIndexToNameMap_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  DataValueMap& generatedShapeData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
};

void checkShapesAndTypes(const TypeProto_Sequence& inferredType, const TypeProto_Sequence& existingType);

void checkShapesAndTypes(const TypeProto& inferredType, const TypeProto& existingType);

template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferredType, SymbolTable& symbolTable);

void MaterializeSymbolicShape(TypeProto* inferredType, SymbolTable& symbolTable);

void mergeShapesAndTypes(const TypeProto_Tensor& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferredType, TypeProto_SparseTensor* existingType);

void mergeShapesAndTypes(const TypeProto_Sequence& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto& inferredType, TypeProto* existingType);

///
/// ModelLocalFunctionsMap is a map of function id -> model local function proto
/// All the ONNX helper utilities expect the function id == <function_proto.domain>:<function_proto.name>
///
ONNX_API void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = ShapeInferenceOptions(),
    const ModelLocalFunctionsMap& in_model_functions = {});

ONNX_API void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = ShapeInferenceOptions(),
    DataValueMap* generated_shape_data_by_name = nullptr);

ONNX_API void InferShapes(
    const std::string& model_path,
    const std::string& save_path = "",
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = ShapeInferenceOptions(),
    DataValueMap* generated_shape_data_by_name = nullptr);

///
/// ModelLocalFunctionsMap is a map of function id -> model local function proto
/// All the ONNX helper utilities expect the function id == <function_proto.domain>:<function_proto.name>
///
void InferShapeForFunctionNode(
    const FunctionProto& func,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = ShapeInferenceOptions(),
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

///
/// ModelLocalFunctionsMap is a map of function id -> model local function proto
/// All the ONNX helper utilities expect the function id == <function_proto.domain>:<function_proto.name>
///
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = ShapeInferenceOptions(),
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

///
/// Apply type-and-shape-inference based checks to a Function body.
/// Returns the inferred types of the outputs of the function.
/// Inference depends on the types of the inputs of the function as well as
/// the attribute values supplied.
/// A TypeProto with value_case() == TypeProto::ValueCase::VALUE_NOT_SET is used
/// for missing optional parameters.
///
std::vector<TypeProto> InferFunctionOutputTypes(
    const FunctionProto& func_proto,
    const std::vector<TypeProto>& input_types,
    const std::vector<AttributeProto>& attributes);

std::string GetErrorWithNodeInfo(const NodeProto& n, const std::runtime_error& err);

void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbolTable);

} // namespace shape_inference
} // namespace ONNX_NAMESPACE
