# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import torch_nebula

from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
    CheckpointEngine
from deepspeed.utils import logger, log_dist
from deepspeed.nebula.constants import *


def _get_tag_from_path(path):
    return os.path.basename(os.path.dirname(path))


class NebulaCheckpointEngine(CheckpointEngine):

    def __init__(self, config_params=None):
        super().__init__(config_params)
        self.checkpoint = None
        self.tag_flag = None
        self.enable_nebula_load = config_params.enable_nebula_load
        self.nebula_load_path = config_params.load_path
        if self.nebula_load_path is None:
            self.nebula_load_path = config_params.persistent_storage_path

        nebula_config_params = {
            NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path,
            NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_time_interval,
            NEBULA_NUM_OF_VERSION_IN_RETENTION: config_params.num_of_version_in_retention,
        }
        torch_nebula.init(**nebula_config_params)

    def create(self, tag):
        log_dist(f"[Nebula] Start Checkpoint for tag:{tag}", ranks=[0])
        # -2 means: customer needs to  explicitly tell nebula
        # current checkpoint is complete by commit method.
        self.checkpoint = torch_nebula.Checkpoint(tag, -2)

    def save(self, state_dict, path: str):
        log_dist(f"[Nebula] Create dummy files for loading.")
        torch.save("", path)

        tag = _get_tag_from_path(path)
        partition_name = os.path.basename(path)
        logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
        self.checkpoint.save(partition_name, state_dict)
        logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")
        return None

    def load(self, path: str, map_location=None):
        tag = _get_tag_from_path(path)
        first_load_flag = self.tag_flag is None or self.tag_flag == tag
        if not self.enable_nebula_load and first_load_flag:
            self.tag_flag = tag
            logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
            partition = torch.load(path, map_location=map_location, weights_only=False)
            logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
            return partition

        partition_name = os.path.basename(path)
        logger.info(f"[Nebula] Loading {path} under tag {tag} from nebula path {self.nebula_load_path}...")

        checkpoint = None
        if tag in (None, 'latest', 'latest_universal'):
            # In some cases, there is the inconsistent tag between deepspeed metadata (latest file)
            # and nebula metadata, will lead to the failure on loading with deepspeed tag. Then we
            # will try to load the valid latest checkpoint from nebula(tier3 > tier1). So, in summary
            # when met failure loading for given tag, the loading priority would be like:
            #               nebula tier3 latest > nebula tier1 latest.
            checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
        else:
            checkpoint = torch_nebula.get_checkpoint(tag=tag, persist_path=self.nebula_load_path)

        if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
            logger.info(
                f"Unable to find valid checkpoint tag:{tag} from Nebula, try to get latest checkpoint again from nebula {self.nebula_load_path} path!"
            )
            # nebula tier3 latest
            checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
            if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
                logger.info(
                    f"Unable to find latest checkpoint from Nebula tier3, try to get latest checkpoint again from nebula tier1 path!"
                )
                # nebula tier1 latest
                checkpoint = torch_nebula.get_latest_checkpoint()
                logger.warning(f"Unable to find valid checkpoint from Nebula under tag:{tag}.")
                return None

        tag = checkpoint.tag
        self.tag_flag = -1
        partition = checkpoint.load(partition_name, map_location=map_location)
        logger.info(f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.")
        return partition

    def commit(self, tag):
        # nebula commit will be call when all files under give tag are ready to be persisted in the async way.
        logger.info(f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting")
        commit_rls = self.checkpoint.commit()
        if not commit_rls:
            logger.error(f"[Nebula] failed to commit the checkpoint, please check the log.")
            return False
        return commit_rls
