// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "deepcompile.h"

#define USE_C10D_NCCL

namespace dc {

std::shared_ptr<DSParamRegistry> param_registry;
std::unordered_map<long, std::shared_ptr<CustomOpExecutor>> executors;
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets = nullptr;

c10::intrusive_ptr<c10d::ProcessGroup> process_group = nullptr;
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem = nullptr;
ncclComm_t nccl_comm;
bool use_symm_mem;
bool clone_custom_op_output;
bool profile = false;
bool pre_div_reduce = true;

bool sync_before_reduce;     // for debugging
bool sync_after_reduce;      // for debugging
bool sync_before_allgather;  // for debugging
bool sync_after_allgather;   // for debugging

std::vector<int64_t> sizes_to_int_vector(at::IntArrayRef sizes)
{
    std::vector<int64_t> result;
    for (int i = 0; i < sizes.size(); i++) { result.push_back(sizes[i]); }
    return result;
}

void enable_profiling(bool enable) { profile = enable; }

bool is_profiling() { return profile; }

c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> getSymmMemWorkspace(int64_t size)
{
    c10::Device device = c10::Device(c10::kCUDA, c10::cuda::current_device());
    std::vector<int64_t> sizes = {size};
    std::vector<int64_t> strides = {1};
    at::Tensor sym_mem_ws = c10d::symmetric_memory::empty_strided_p2p(
        {size}, {1}, c10::ScalarType::Byte, device, process_group->getGroupName(), std::nullopt);
    return c10d::symmetric_memory::rendezvous(sym_mem_ws);
}

void lazy_init_symm_memory()
{
    if (use_symm_mem && !symm_mem) {
        int64_t max_param_size = 0;
        for (const auto& it : param_registry->getParams()) {
            int64_t size = it.second.getDSTensor().numel() * it.second.getDSTensor().element_size();
            if (size > max_param_size) { max_param_size = size; }
        }
        symm_mem = getSymmMemWorkspace(max_param_size);
    }
}

ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type)
{
    switch (scalar_type) {
        case at::kFloat: return ncclFloat;
        case at::kHalf: return ncclHalf;
        case at::kDouble: return ncclDouble;
        case at::kBFloat16: return ncclBfloat16;
        case at::kLong: return ncclInt64;
        case at::kInt: return ncclInt;
        case at::kChar: return ncclInt8;
        default: throw std::runtime_error("Unsupported scalar type");
    }
}

void reset()
{
    executors.clear();
    // We keep the buckets for memory estimation
    // reduce_buckets->clear();
}

void cleanup()
{
    reset();

    ncclCommDestroy(nccl_comm);
    process_group = nullptr;
    symm_mem = nullptr;
}

at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id)
{
    if (sync_before_reduce) { c10::cuda::device_synchronize(); }

    assert(hasKey(executors, graph_id));
    if (!profile) { executors[graph_id]->reduceGrad(grad_tensor, ds_id); }

    if (sync_after_reduce) { c10::cuda::device_synchronize(); }

    return at::Tensor();
}

at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
{
    return at::Tensor();
}

void free_tensors(std::vector<at::Tensor> tensors)
{
    int64_t THRESHOLD = 10 * 1024 * 1024;

    if (!profile) {
        for (auto& tensor : tensors) {
            if (tensor.is_cuda() && tensor.numel() > THRESHOLD) {
                tensor.record_stream(at::cuda::getCurrentCUDAStream());
                tensor.set_data(torch::empty({0}, tensor.options()));
            }
        }
    }
}

void free_tensors_meta(std::vector<at::Tensor> tensors) {}

void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
          int64_t initial_reduce_bucket_size,
          bool enable_double_buffer,
          bool _use_symm_mem,
          bool _clone_custom_op_output,
          bool _sync_before_reduce,
          bool _sync_after_reduce,
          bool _sync_before_allgather,
          bool _sync_after_allgather)
{
    process_group = pg;

    ncclUniqueId ncclID;
    ncclGetUniqueId(&ncclID);

    // ProcessGroup doesn't have an API to get the CUDA stream for comm calls.
    // So we create a NCCL communicator and call NCCL APIs directly.
    auto vec = std::vector<uint8_t>(reinterpret_cast<uint8_t*>(&ncclID),
                                    reinterpret_cast<uint8_t*>(&ncclID) + NCCL_UNIQUE_ID_BYTES);
    auto device = torch::Device(torch::kCUDA);
    at::Tensor tensor = torch::from_blob(vec.data(), {static_cast<long>(vec.size())}, torch::kUInt8)
                            .to(torch::Device(torch::kCUDA));
    std::vector<at::Tensor> bcast_input = {tensor};

    process_group->broadcast(bcast_input, c10d::BroadcastOptions())->wait();

    // create a new nccl communicator
    std::memcpy(&ncclID, tensor.to(torch::Device(torch::kCPU)).data_ptr(), NCCL_UNIQUE_ID_BYTES);
    ncclCommInitRank(&nccl_comm, process_group->getSize(), ncclID, process_group->getRank());

    param_registry = std::make_shared<DSParamRegistry>();
    reduce_buckets = std::make_shared<DoubleBufferedReduceBucket>(initial_reduce_bucket_size,
                                                                  enable_double_buffer);
    use_symm_mem = _use_symm_mem;
    clone_custom_op_output = _clone_custom_op_output;

    sync_before_reduce = _sync_before_reduce;
    sync_after_reduce = _sync_after_reduce;
    sync_before_allgather = _sync_before_allgather;
    sync_after_allgather = _sync_after_allgather;
}

void start_forward()
{
    lazy_init_symm_memory();
    for (auto& it : executors) { it.second->startForward(); }
}

void end_forward()
{
    for (auto& it : executors) { it.second->endForward(); }
}

void start_backward(bool update)
{
    for (auto& it : executors) { it.second->startBackward(update); }
}

// We don't call this
// void end_backward(bool update)
// {
// }

}  // namespace dc
