// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <torch/types.h>
#include <memory>
#include <mutex>
#include <vector>

namespace facebook::torchcodec {

// This header defines simple cache class primitives to store reusable objects
// across TorchCodec stream instances. The intended usage is to store hardware
// contexts whose creation is expensive. The cache mechanism is as follows:
// 1. 'PerGpuCache' provides a dynamic cache with the specified maximum
//    capacity for the given number of GPUs.
// 2. When a stream object (e.g. SingleStreamDecoder) is destroyed, the
//    cacheable object must be released to the cache. The cache accepts the
//    object only if it is not full.
// 3. When a stream object (e.g. SingleStreamDecoder) is created, the
//    cacheable object must first be queried from the cache. If the cache is
//    empty, a new object must be created.

// Thread-safe pool of reusable objects. All public operations lock an
// internal mutex (see the out-of-line definitions below), so the cache can
// be shared across threads. Items are owned by the cache while stored and
// handed out as std::unique_ptr<T, D>.
template <typename T, typename D = std::default_delete<T>>
class Cache {
 public:
  // Owning handle type stored in, and returned by, the cache.
  using element_type = std::unique_ptr<T, D>;

  // 'capacity' is the maximum number of objects the cache will hold.
  // A negative capacity disables the limit (see addIfCacheHasCapacity).
  explicit Cache(int capacity) : capacity_(capacity) {}

  // Adds an object to the cache if the cache has capacity. Returns true
  // if object was added and false otherwise. On false the caller keeps
  // ownership of 'obj'.
  bool addIfCacheHasCapacity(element_type&& obj);

  // Returns an object from the cache. Cache does not hold a reference
  // to the object after this call. Returns nullptr if the cache is empty.
  element_type get();

 private:
  int capacity_; // < 0 means "unlimited"
  std::mutex mutex_; // guards cache_; also makes Cache non-movable
  std::vector<element_type> cache_;
};

// Hands 'obj' over to the cache unless the configured capacity has already
// been reached; a negative capacity disables the limit. Returns true when
// ownership was transferred, false when the caller retains 'obj'.
// Thread-safe: the whole operation runs under mutex_.
template <typename T, typename D>
bool Cache<T, D>::addIfCacheHasCapacity(element_type&& obj) {
  std::scoped_lock lock(mutex_);
  const bool bounded = capacity_ >= 0;
  if (bounded && cache_.size() >= static_cast<size_t>(capacity_)) {
    // Full: reject without moving from 'obj'.
    return false;
  }
  cache_.push_back(std::move(obj));
  return true;
}

// Pops one object off the cache and transfers ownership to the caller,
// or yields nullptr when nothing is stored. Thread-safe: runs under mutex_.
template <typename T, typename D>
typename Cache<T, D>::element_type Cache<T, D>::get() {
  std::scoped_lock lock(mutex_);
  if (!cache_.empty()) {
    element_type result = std::move(cache_.back());
    cache_.pop_back();
    return result;
  }
  return nullptr;
}

// A fixed collection of independent per-GPU 'Cache' instances, one for each
// of 'maxGpus' devices, addressed via torch::Device.
template <typename T, typename D = std::default_delete<T>>
class PerGpuCache {
 public:
  using element_type = typename Cache<T, D>::element_type;

  // Builds 'maxGpus' caches, each holding at most 'capacity' items.
  // A negative 'capacity' makes every per-GPU cache unbounded.
  PerGpuCache(int maxGpus, int capacity) {
    TORCH_CHECK(maxGpus > 0, "maxGpus for PerGpuCache must be >0");
    cache_.reserve(static_cast<size_t>(maxGpus));
    int remaining = maxGpus;
    while (remaining-- > 0) {
      cache_.push_back(std::make_unique<Cache<T, D>>(capacity));
    }
  }

  // Adds an object to the specified device cache if the cache has
  // capacity. Returns true if object was added and false otherwise.
  bool addIfCacheHasCapacity(const torch::Device& device, element_type&& obj);

  // Returns an object from the cache of the specified device. Cache
  // does not hold a reference to the object after this call.
  element_type get(const torch::Device& device);

 private:
  // 'Cache' contains a std::mutex, which makes it non-movable and
  // non-copyable, so each entry is held through a std::unique_ptr.
  std::vector<std::unique_ptr<Cache<T, D>>> cache_;
};

// Forward declaration of getDeviceIndex which exists in CUDACommon.h
// This avoids circular dependency between Cache.h and CUDACommon.cpp which also
// needs to include Cache.h
int getDeviceIndex(const torch::Device& device);

// Adds 'obj' to the cache belonging to 'device'. Returns true if that cache
// accepted the object (it had spare capacity) and false otherwise.
// Fails via TORCH_CHECK when the device index is outside [0, maxGpus).
template <typename T, typename D>
bool PerGpuCache<T, D>::addIfCacheHasCapacity(
    const torch::Device& device,
    element_type&& obj) {
  int deviceIndex = getDeviceIndex(device);
  // Reject negative indices explicitly rather than relying on unsigned
  // wraparound in the static_cast below to trip the size comparison.
  TORCH_CHECK(
      deviceIndex >= 0 && static_cast<size_t>(deviceIndex) < cache_.size(),
      "Device index out of range");
  return cache_[deviceIndex]->addIfCacheHasCapacity(std::move(obj));
}

// Returns an object from the cache of 'device', transferring ownership to
// the caller; returns nullptr when that cache is empty.
// Fails via TORCH_CHECK when the device index is outside [0, maxGpus).
template <typename T, typename D>
typename PerGpuCache<T, D>::element_type PerGpuCache<T, D>::get(
    const torch::Device& device) {
  int deviceIndex = getDeviceIndex(device);
  // Reject negative indices explicitly rather than relying on unsigned
  // wraparound in the static_cast below to trip the size comparison.
  TORCH_CHECK(
      deviceIndex >= 0 && static_cast<size_t>(deviceIndex) < cache_.size(),
      "Device index out of range");
  return cache_[deviceIndex]->get();
}

} // namespace facebook::torchcodec
