# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .constants import *
import copy
from ..config_utils import get_scalar_param


# TODO: Reduce config verbosity by returning None or {} when a feature is disabled.
# One challenge is that the default values still need to be included somehow;
# for example, each *_ENABLED flag has a default of False.
def get_data_efficiency_config(param_dict):
    """Build the normalized data-efficiency config from the user config.

    Side effect: if ``param_dict`` has no ``DATA_EFFICIENCY`` entry, an empty
    dict is inserted into ``param_dict`` under that key.

    Args:
        param_dict: the top-level configuration dictionary.

    Returns:
        A dict holding the enabled flag, the seed, and the normalized
        data-sampling and data-routing sub-configs.
    """
    enabled = get_data_efficiency_enabled(param_dict)
    seed = get_data_efficiency_seed(param_dict)
    # Insert an empty section when absent so the sub-parsers always get a dict.
    efficiency_section = param_dict.setdefault(DATA_EFFICIENCY, {})
    return {
        DATA_EFFICIENCY_ENABLED: enabled,
        DATA_EFFICIENCY_SEED: seed,
        DATA_SAMPLING: get_data_sampling(efficiency_section),
        DATA_ROUTING: get_data_routing(efficiency_section),
    }


def get_data_efficiency_enabled(param_dict):
    """Return the data-efficiency enabled flag.

    Returns ``False`` when ``param_dict`` has no ``DATA_EFFICIENCY`` section.
    Otherwise returns the section's ``DATA_EFFICIENCY_ENABLED`` value, falling
    back to ``DATA_EFFICIENCY_ENABLED_DEFAULT``.
    """
    if DATA_EFFICIENCY not in param_dict:
        return False
    return get_scalar_param(param_dict[DATA_EFFICIENCY], DATA_EFFICIENCY_ENABLED, DATA_EFFICIENCY_ENABLED_DEFAULT)


def get_data_efficiency_seed(param_dict):
    """Return the data-efficiency seed.

    Falls back to ``DATA_EFFICIENCY_SEED_DEFAULT`` whether the
    ``DATA_EFFICIENCY`` section is missing or merely lacks the seed key.
    """
    if DATA_EFFICIENCY not in param_dict:
        return DATA_EFFICIENCY_SEED_DEFAULT
    return get_scalar_param(param_dict[DATA_EFFICIENCY], DATA_EFFICIENCY_SEED, DATA_EFFICIENCY_SEED_DEFAULT)


def get_data_sampling(param_dict):
    """Build the normalized data-sampling config.

    Starts from a shallow copy of the ``DATA_SAMPLING`` section, so unknown
    user keys are passed through. It then overwrites the scalar settings and
    the curriculum-learning and dynamic-batching sub-configs with their
    normalized values.

    Args:
        param_dict: the ``DATA_EFFICIENCY`` section of the config.

    Returns:
        A new dict with the normalized data-sampling settings.
    """
    sampling_section = param_dict.get(DATA_SAMPLING, {})
    result = copy.copy(sampling_section)

    # Scalar getters read from the parent dict because they look up
    # DATA_SAMPLING themselves.
    scalar_getters = (
        (DATA_SAMPLING_ENABLED, get_data_sampling_enabled),
        (DATA_SAMPLING_NUM_EPOCHS, get_data_sampling_num_epochs),
        (DATA_SAMPLING_NUM_WORKERS, get_data_sampling_num_workers),
        (DATA_SAMPLING_PIN_MEMORY, get_data_sampling_pin_memory),
    )
    for key, getter in scalar_getters:
        result[key] = getter(param_dict)

    # Nested sub-configs are parsed from the data-sampling section itself.
    result[CURRICULUM_LEARNING] = get_curriculum_learning(sampling_section)
    result[DYNAMIC_BATCHING] = get_dynamic_batching(sampling_section)
    return result


def get_data_sampling_enabled(param_dict):
    """Return the data-sampling enabled flag.

    Returns ``False`` when ``DATA_SAMPLING`` is absent. Otherwise returns the
    section's value, defaulting to ``DATA_SAMPLING_ENABLED_DEFAULT``.
    """
    if DATA_SAMPLING not in param_dict:
        return False
    return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_ENABLED, DATA_SAMPLING_ENABLED_DEFAULT)


def get_data_sampling_num_epochs(param_dict):
    """Return the configured number of data-sampling epochs.

    Uses ``DATA_SAMPLING_NUM_EPOCHS_DEFAULT`` when the section or key is absent.
    """
    if DATA_SAMPLING not in param_dict:
        return DATA_SAMPLING_NUM_EPOCHS_DEFAULT
    return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_NUM_EPOCHS, DATA_SAMPLING_NUM_EPOCHS_DEFAULT)


def get_data_sampling_num_workers(param_dict):
    """Return the configured number of data-sampling workers.

    Uses ``DATA_SAMPLING_NUM_WORKERS_DEFAULT`` when the section or key is absent.
    """
    if DATA_SAMPLING not in param_dict:
        return DATA_SAMPLING_NUM_WORKERS_DEFAULT
    sampling_section = param_dict[DATA_SAMPLING]
    return get_scalar_param(sampling_section, DATA_SAMPLING_NUM_WORKERS, DATA_SAMPLING_NUM_WORKERS_DEFAULT)


def get_data_sampling_pin_memory(param_dict):
    """Return the data-sampling pin-memory setting.

    Uses ``DATA_SAMPLING_PIN_MEMORY_DEFAULT`` when the section or key is absent.
    """
    if DATA_SAMPLING not in param_dict:
        return DATA_SAMPLING_PIN_MEMORY_DEFAULT
    sampling_section = param_dict[DATA_SAMPLING]
    return get_scalar_param(sampling_section, DATA_SAMPLING_PIN_MEMORY, DATA_SAMPLING_PIN_MEMORY_DEFAULT)


def get_curriculum_learning(param_dict):
    """Build the normalized curriculum-learning config.

    Side effect: inserts an empty ``CURRICULUM_LEARNING`` dict into
    ``param_dict`` when that key is missing.

    Args:
        param_dict: the data-sampling section of the config.

    Returns:
        A dict that always contains ``CURRICULUM_LEARNING_ENABLED``. When
        enabled, it also holds every other key of the curriculum section.

    Raises:
        AssertionError: if curriculum learning is enabled but
            ``CURRICULUM_LEARNING_METRICS`` is not specified.
    """
    enabled = get_curriculum_learning_enabled(param_dict)
    result = {CURRICULUM_LEARNING_ENABLED: enabled}
    curriculum_section = param_dict.setdefault(CURRICULUM_LEARNING, {})
    if not enabled:
        return result

    assert CURRICULUM_LEARNING_METRICS in curriculum_section, \
        f"Curriculum learning is enabled, {CURRICULUM_LEARNING_METRICS} must be specified"
    result.update(get_curriculum_learning_params(param_dict))
    return result


def get_dynamic_batching(param_dict):
    """Build the normalized dynamic-batching config.

    Works on a shallow copy of the ``DYNAMIC_BATCHING`` section, so unknown
    user keys are passed through. Known settings are coerced to bool/str/int
    and missing ones are filled from the module defaults.

    Args:
        param_dict: the data-sampling section of the config.

    Returns:
        A new dict with the normalized dynamic-batching settings.

    Raises:
        AssertionError: if dynamic batching is enabled but
            ``DYNAMIC_BATCHING_MAX_TOKENS`` is not specified.
    """
    section = copy.copy(param_dict.get(DYNAMIC_BATCHING, {}))

    def _coerce(key, cast, default):
        # Store cast(user value), or cast(default) when the key is missing.
        section[key] = cast(section.get(key, default))

    _coerce(DYNAMIC_BATCHING_ENABLED, bool, DYNAMIC_BATCHING_ENABLED_DEFAULT)
    _coerce(DYNAMIC_BATCHING_LR_SCALING_METHOD, str, DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT)
    _coerce(DYNAMIC_BATCHING_MIN_BATCH_SIZE, int, DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT)

    # The max-batch-size default is stored as-is, without int(). Presumably
    # it may be a non-integer sentinel such as None (TODO: confirm in constants).
    if DYNAMIC_BATCHING_MAX_BATCH_SIZE in section:
        section[DYNAMIC_BATCHING_MAX_BATCH_SIZE] = int(section[DYNAMIC_BATCHING_MAX_BATCH_SIZE])
    else:
        section[DYNAMIC_BATCHING_MAX_BATCH_SIZE] = DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT

    _coerce(DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER, str, DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT)

    # max_tokens has no default; it is mandatory only when the feature is on.
    if section[DYNAMIC_BATCHING_ENABLED]:
        assert DYNAMIC_BATCHING_MAX_TOKENS in section, \
            f"Dynamic batching is enabled, so {DYNAMIC_BATCHING_MAX_TOKENS} must be specified"
        section[DYNAMIC_BATCHING_MAX_TOKENS] = int(section[DYNAMIC_BATCHING_MAX_TOKENS])

    _coerce(DYNAMIC_BATCHING_VERBOSE, bool, False)
    return section


def get_curriculum_learning_enabled(param_dict):
    """Return the curriculum-learning enabled flag.

    Returns ``False`` when ``CURRICULUM_LEARNING`` is absent. Otherwise returns
    the section's value, defaulting to ``CURRICULUM_LEARNING_ENABLED_DEFAULT``.
    """
    if CURRICULUM_LEARNING not in param_dict:
        return False
    curriculum_section = param_dict[CURRICULUM_LEARNING]
    return get_scalar_param(curriculum_section, CURRICULUM_LEARNING_ENABLED, CURRICULUM_LEARNING_ENABLED_DEFAULT)


def get_curriculum_learning_params(param_dict):
    """Return the curriculum-learning section without its enabled flag.

    Args:
        param_dict: the data-sampling section of the config.

    Returns:
        A shallow copy of ``param_dict[CURRICULUM_LEARNING]`` with
        ``CURRICULUM_LEARNING_ENABLED`` removed, or ``{}`` when the section is
        absent. The input dict is not modified.
    """
    if CURRICULUM_LEARNING not in param_dict:
        return {}
    curriculum_learning_params = copy.copy(param_dict[CURRICULUM_LEARNING])
    # The enabled flag may be omitted by the user (its default then applies);
    # pop with a default so that case does not raise KeyError.
    curriculum_learning_params.pop(CURRICULUM_LEARNING_ENABLED, None)
    return curriculum_learning_params


def get_curriculum_enabled_legacy(param_dict):
    """Return the enabled flag of the legacy curriculum-learning section.

    Returns ``False`` when ``CURRICULUM_LEARNING_LEGACY`` is absent. Otherwise
    returns the section's value, defaulting to ``CURRICULUM_ENABLED_DEFAULT_LEGACY``.
    """
    if CURRICULUM_LEARNING_LEGACY not in param_dict:
        return False
    legacy_section = param_dict[CURRICULUM_LEARNING_LEGACY]
    return get_scalar_param(legacy_section, CURRICULUM_ENABLED_LEGACY, CURRICULUM_ENABLED_DEFAULT_LEGACY)


def get_curriculum_params_legacy(param_dict):
    """Return the legacy curriculum-learning section without its enabled flag.

    Args:
        param_dict: the configuration dict that may contain the legacy
            ``CURRICULUM_LEARNING_LEGACY`` section.

    Returns:
        A shallow copy of the legacy section with ``CURRICULUM_ENABLED_LEGACY``
        removed. Returns ``False`` when the section is absent; this sentinel
        is kept for backward compatibility with existing callers.
    """
    if CURRICULUM_LEARNING_LEGACY not in param_dict:
        return False
    curriculum_params = copy.copy(param_dict[CURRICULUM_LEARNING_LEGACY])
    # A legacy section may omit the enabled flag. A bare pop() would then raise
    # KeyError during config parsing, so pop with a default instead.
    curriculum_params.pop(CURRICULUM_ENABLED_LEGACY, None)
    return curriculum_params


def get_data_routing(param_dict):
    """Build the normalized data-routing config.

    Side effect: inserts an empty ``DATA_ROUTING`` dict into ``param_dict``
    when that key is missing.

    Args:
        param_dict: the ``DATA_EFFICIENCY`` section of the config.

    Returns:
        A dict with the data-routing enabled flag and the random-LTD sub-config.
    """
    routing_enabled = get_data_routing_enabled(param_dict)
    routing_section = param_dict.setdefault(DATA_ROUTING, {})
    return {
        DATA_ROUTING_ENABLED: routing_enabled,
        RANDOM_LTD: get_random_ltd(routing_section),
    }


def get_data_routing_enabled(param_dict):
    """Return the data-routing enabled flag.

    Returns ``False`` when ``DATA_ROUTING`` is absent. Otherwise returns the
    section's value, defaulting to ``DATA_ROUTING_ENABLED_DEFAULT``.
    """
    if DATA_ROUTING not in param_dict:
        return False
    routing_section = param_dict[DATA_ROUTING]
    return get_scalar_param(routing_section, DATA_ROUTING_ENABLED, DATA_ROUTING_ENABLED_DEFAULT)


def get_random_ltd(param_dict):
    """Build the normalized random-LTD config.

    Always returns the default enabled flag and a layer-token LR schedule whose
    enabled flag holds its default. When random-LTD is enabled in
    ``param_dict``, every other user-supplied key of the section is merged on
    top; this can replace the default schedule dict entirely.

    Args:
        param_dict: the data-routing section of the config.

    Returns:
        A new dict with the normalized random-LTD settings.
    """
    result = {
        RANDOM_LTD_ENABLED: RANDOM_LTD_ENABLED_DEFAULT,
        RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE: {
            RANDOM_LTD_LAYER_TOKEN_LR_ENABLED: RANDOM_LTD_LAYER_TOKEN_LR_ENABLED_DEFAULT,
        },
    }
    enabled = get_random_ltd_enabled(param_dict)
    if enabled:
        result[RANDOM_LTD_ENABLED] = enabled
        result.update(get_random_ltd_params(param_dict))
    return result


def get_random_ltd_enabled(param_dict):
    """Return the random-LTD enabled flag.

    Returns ``False`` when ``RANDOM_LTD`` is absent. Otherwise returns the
    section's value, defaulting to ``RANDOM_LTD_ENABLED_DEFAULT``.
    """
    if RANDOM_LTD not in param_dict:
        return False
    ltd_section = param_dict[RANDOM_LTD]
    return get_scalar_param(ltd_section, RANDOM_LTD_ENABLED, RANDOM_LTD_ENABLED_DEFAULT)


def get_random_ltd_params(param_dict):
    """Return the random-LTD section without its enabled flag.

    Args:
        param_dict: the data-routing section of the config.

    Returns:
        A shallow copy of ``param_dict[RANDOM_LTD]`` with
        ``RANDOM_LTD_ENABLED`` removed, or ``{}`` when the section is absent.
        The input dict is not modified.
    """
    if RANDOM_LTD not in param_dict:
        return {}
    random_ltd_params = copy.copy(param_dict[RANDOM_LTD])
    # The enabled flag may be omitted (its default then applies); pop with a
    # default so that case does not raise KeyError.
    random_ltd_params.pop(RANDOM_LTD_ENABLED, None)
    return random_ltd_params
