// Copyright (c) ONNX Project Contributors

/*
 * SPDX-License-Identifier: Apache-2.0
 */

// Interface for Op Version Adapters

#pragma once

#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "onnx/onnx_pb.h"
#include "onnx/version_converter/helper.h"

namespace ONNX_NAMESPACE {
namespace version_conversion {

/// Abstract interface for a single op version adapter.
///
/// An adapter carries a name plus the pair of opset versions it converts
/// between (`initial_version()` -> `target_version()`). Concrete adapters
/// implement `adapt()` to rewrite one node of a graph.
class Adapter {
 public:
  /// @param name            identifier stored for this adapter (GenericAdapter
  ///                        passes its `op` argument here — presumably the op type).
  /// @param initial_version opset version the adapter converts from.
  /// @param target_version  opset version the adapter converts to.
  Adapter(std::string name, OpSetID initial_version, OpSetID target_version)
      : name_(std::move(name)),
        initial_version_(std::move(initial_version)),
        target_version_(std::move(target_version)) {}

  virtual ~Adapter() noexcept = default;

  /// Converts `node` (which lives in `graph`) and returns the resulting node.
  ///
  /// In almost every case this is `node` itself, modified in place. The only
  /// exception is adapters for deprecated operators: they must destroy the
  /// input node and create and return a new one (see e.g. upsample_9_10.h).
  virtual Node* adapt(std::shared_ptr<Graph> /*graph*/, Node* node) const = 0;

  /// Name given to this adapter at construction.
  const std::string& name() const { return name_; }

  /// Opset version this adapter converts from.
  const OpSetID& initial_version() const { return initial_version_; }

  /// Opset version this adapter converts to.
  const OpSetID& target_version() const { return target_version_; }

 private:
  std::string name_;
  OpSetID initial_version_;
  OpSetID target_version_;
};

// Callable form of Adapter::adapt: receives the graph and the node to convert,
// and returns the node that replaces it (usually the same node, modified in place).
using NodeTransformerFunction = std::function<Node*(std::shared_ptr<Graph> graph, Node* node)>;

/// Adapter whose conversion logic is supplied as a callable instead of
/// being implemented in a dedicated subclass.
class GenericAdapter final : public Adapter {
 public:
  /// @param op          name passed to the base Adapter; must be a valid,
  ///                    non-null C string (it is converted to std::string).
  /// @param from        version used to build the initial OpSetID via OpSetID(from).
  /// @param to          version used to build the target OpSetID via OpSetID(to).
  /// @param transformer callable invoked by adapt(); stored by value. If it is
  ///                    empty, adapt() throws std::bad_function_call.
  GenericAdapter(const char* op, int64_t from, int64_t to, NodeTransformerFunction transformer)
      : Adapter(op, OpSetID(from), OpSetID(to)), transformer_(std::move(transformer)) {}

  /// Forwards to the stored transformer and returns whatever node it returns.
  Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
    // `graph` is already our own by-value copy and is not used after this call,
    // so move it into the transformer. This avoids a second shared_ptr copy
    // (an extra atomic refcount increment/decrement) for every adapted node.
    return transformer_(std::move(graph), node);
  }

 private:
  NodeTransformerFunction transformer_;
};

} // namespace version_conversion
} // namespace ONNX_NAMESPACE
