# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import sys
from dataclasses import is_dataclass
from typing import Dict, List, Optional

from nemo.utils import logging

# TODO @blisc: Perhaps refactor instead of import guarding
_HAS_HYDRA = True
try:
    from omegaconf import DictConfig, OmegaConf, open_dict
except ModuleNotFoundError:
    _HAS_HYDRA = False


def update_model_config(
    model_cls: 'nemo.core.config.modelPT.NemoConfig', update_cfg: 'DictConfig', drop_missing_subconfigs: bool = True
):
    """
    Helper function that updates the default values of a ModelPT config class with the values
    in a DictConfig that mirrors the structure of the config class.

    Assumes the `update_cfg` is a DictConfig (either generated manually, via hydra or instantiated via yaml/model.cfg).
    This update_cfg is then used to override the default values preset inside the ModelPT config class.

    If `drop_missing_subconfigs` is set, certain sub-configs of the ModelPT config class will be removed if
    they are not found in the mirrored `update_cfg`. The following sub-configs are subject to potential removal:
        -   `train_ds`
        -   `validation_ds`
        -   `test_ds`
        -   `optim` + nested `sched`.

    Neither `model_cls` nor `update_cfg` is modified by this function: all intermediate edits (adding/dropping
    sub-configs, removing ModelPT artifacts) are performed on private copies.

    Args:
        model_cls: A subclass of NemoConfig (or an equivalent structured DictConfig), that details in entirety
            all of the parameters that constitute the NeMo Model.

        update_cfg: A DictConfig that mirrors the structure of the NemoConfig data class. Used to update the
            default values of the config class. Any other mapping is converted with `OmegaConf.create`.

        drop_missing_subconfigs: Bool which determines whether to drop certain sub-configs from the NemoConfig
            class, if the corresponding sub-config is missing from `update_cfg`.

    Returns:
        A DictConfig with updated values that can be used to instantiate the NeMo Model along with supporting
        infrastructure.

    Raises:
        ValueError: If `model_cls` is neither a dataclass nor a DictConfig.
        SystemExit: (status 1) If OmegaConf is not installed.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
        # `sys.exit` rather than the `site`-injected `exit`, which is intended for interactive use only.
        sys.exit(1)
    if not (is_dataclass(model_cls) or isinstance(model_cls, DictConfig)):
        raise ValueError("`model_cls` must be a dataclass or a structured OmegaConf object")

    if isinstance(update_cfg, DictConfig):
        # Copy so the artifact removal below does not pop keys out of the caller's config.
        update_cfg = copy.deepcopy(update_cfg)
    else:
        update_cfg = OmegaConf.create(update_cfg)

    if is_dataclass(model_cls):
        model_cls = OmegaConf.structured(model_cls)
    else:
        # The sub-config helpers edit `model_cls.model` in place; copy so a caller-owned template
        # is not altered (which would break re-use of the same template across calls).
        model_cls = copy.deepcopy(model_cls)

    # Update optional configs: add sub-configs present only in `update_cfg`, optionally drop the others.
    for subconfig_key in ('train_ds', 'validation_ds', 'test_ds', 'optim'):
        model_cls = _update_subconfig(
            model_cls, update_cfg, subconfig_key=subconfig_key, drop_missing_subconfigs=drop_missing_subconfigs
        )

    # Add optim and sched additional keys to model cls
    model_cls = _add_subconfig_keys(model_cls, update_cfg, subconfig_key='optim')

    # Remove ModelPT artifacts (`target`, `nemo_version`) from the update config when the
    # config class does not declare them; otherwise the structured merge below would reject them.
    for artifact_key in ('target', 'nemo_version'):
        if artifact_key in update_cfg.model and artifact_key not in model_cls.model:
            with open_dict(update_cfg.model):
                update_cfg.model.pop(artifact_key)

    # Perform full merge of model config class and update config
    return OmegaConf.merge(model_cls, update_cfg)


def _update_subconfig(
    model_cfg: 'DictConfig', update_cfg: 'DictConfig', subconfig_key: str, drop_missing_subconfigs: bool
):
    """
    Updates the NemoConfig DictConfig in place such that:
    1)  If the sub-config key exists in the `update_cfg`, but does not exist in ModelPT config:
        - Add the sub-config from update_cfg to ModelPT config

    2)  If the sub-config key does not exist in `update_cfg`, but exists in ModelPT config:
        - Remove the sub-config from the ModelPT config; iff the `drop_missing_subconfigs` flag is set.

    Args:
        model_cfg: A DictConfig instantiated from the NemoConfig subclass. Modified in place.
        update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
        subconfig_key: A str key used to check and update the sub-config.
        drop_missing_subconfigs: A bool flag, whether to allow deletion of the NemoConfig sub-config,
            if its mirror sub-config does not exist in the `update_cfg`.

    Returns:
        The same `model_cfg` object, after the in-place update.

    Raises:
        SystemExit: (status 1) If OmegaConf is not installed.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
        # `sys.exit` rather than the `site`-injected `exit`, which is intended for interactive use only.
        sys.exit(1)

    # `open_dict` allows adding/removing keys even when `model_cfg` is in struct mode.
    with open_dict(model_cfg.model):
        in_update = subconfig_key in update_cfg.model
        in_model = subconfig_key in model_cfg.model

        if in_update and not in_model:
            # Sub-config only provided by the update config: adopt it as-is.
            model_cfg.model[subconfig_key] = update_cfg.model[subconfig_key]
        elif in_model and not in_update and drop_missing_subconfigs:
            # Sub-config absent from the update config: drop it to match the update config's layout.
            model_cfg.model.pop(subconfig_key)

    return model_cfg


def _add_subconfig_keys(model_cfg: 'DictConfig', update_cfg: 'DictConfig', subconfig_key: str):
    """
    For certain sub-configs, the default values specified by the NemoConfig class is insufficient.
    In order to support every potential value in the merge between the `update_cfg`, it would require
    explicit definition of all possible cases.

    An example of such a case is Optimizers, and their equivalent Schedulers. All optimizers share a few basic
    details - such as name and lr, but almost all require additional parameters - such as weight decay.
    It is impractical to create a config for every single optimizer + every single scheduler combination.

    In such a case, we perform a dual merge. The Optim and Sched Dataclass contain the bare minimum essential
    components. The extra values are provided via update_cfg.

    In order to enable the merge, we first need to update the update sub-config to incorporate the keys,
    with dummy temporary values (merge update config with model config). This is done on a copy of the
    update sub-config, as the actual override values might be overridden by the NemoConfig defaults.

    Then we perform a merge of this temporary sub-config with the actual override config in a later step
    (merge model_cfg with original update_cfg, done outside this function).

    Args:
        model_cfg: A DictConfig instantiated from the NemoConfig subclass. Modified in place.
        update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
        subconfig_key: A str key used to check and update the sub-config.

    Returns:
        A ModelPT DictConfig (the same `model_cfg` object) with additional keys added to the sub-config.

    Raises:
        SystemExit: (status 1) If OmegaConf is not installed.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
        # `sys.exit` rather than the `site`-injected `exit`, which is intended for interactive use only.
        sys.exit(1)

    # Nothing to add if the update config does not provide this sub-config.
    if subconfig_key not in update_cfg.model:
        return model_cfg

    with open_dict(model_cfg.model):
        if subconfig_key not in model_cfg.model:
            # create the key as a placeholder
            model_cfg.model[subconfig_key] = None

        # Work on copies so neither source node is altered by the merge.
        subconfig = copy.deepcopy(model_cfg.model[subconfig_key])
        update_subconfig = copy.deepcopy(update_cfg.model[subconfig_key])

        # Merge order (update first, model second) brings in the update's extra keys while the model's
        # values win for now; the real override values are applied during the later full merge.
        model_cfg.model[subconfig_key] = OmegaConf.merge(update_subconfig, subconfig)

    return model_cfg


def assert_dataclass_signature_match(
    cls: 'class_type',
    datacls: 'dataclass',
    ignore_args: Optional[List[str]] = None,
    remap_args: Optional[Dict[str, str]] = None,
):
    """
    Compare the argument names of a class's ``__init__`` with the constructor arguments of its dataclass.

    Note:
        Only argument *names* are compared, never values or types. Mismatches are reported via
        ``logging.error``; nothing is actually asserted or raised.

    Args:
        cls: A class type (not an instance). Pass ``type(x)`` if only an instance ``x`` is at hand.
        datacls: The dataclass that is meant to mirror ``cls``.
        ignore_args: (Optional) Argument names excluded from the comparison on both sides, e.g. when the
            dataclass is a superset of the class arguments.
        remap_args: (Optional) Mapping ``old_name -> new_name`` applied to whichever side (class and/or
            dataclass) contains ``old_name``, for arguments renamed by an indirect instantiation helper.

    Returns:
        A 3-tuple:
        1) ``True`` if the (remapped, filtered) argument name sets are identical, ``False`` otherwise.
        2) The set of names only the class accepts, or ``None`` on a match.
        3) The set of names only the dataclass accepts, or ``None`` on a match.
    """
    # Constructor arguments of the class, minus the implicit `self`.
    init_params = dict(inspect.signature(cls.__init__).parameters)
    init_params.pop('self')
    cls_names = set(init_params)

    # Constructor arguments of the dataclass, minus the `_target_` field if present.
    data_params = dict(inspect.signature(datacls).parameters)
    data_params.pop("_target_", None)
    datacls_names = set(data_params)

    if remap_args is not None:
        for old_name, new_name in remap_args.items():
            # Class side is handled before the dataclass side, each logged with its owner's name.
            for names, owner in ((cls_names, cls), (datacls_names, datacls)):
                if old_name in names:
                    names.discard(old_name)
                    names.add(new_name)
                    logging.info(f"Remapped {old_name} -> {new_name} in {owner.__name__}")

    if ignore_args is not None:
        ignored = set(ignore_args)
        cls_names -= ignored
        datacls_names -= ignored
        logging.info(f"Removing ignored arguments - {ignored}")

    only_in_cls = cls_names - datacls_names
    only_in_datacls = datacls_names - cls_names

    if not only_in_cls and not only_in_datacls:
        return True, None, None

    logging.error(f"Class {cls.__name__} arguments do not match Dataclass {datacls.__name__}!")
    if only_in_cls:
        logging.error(f"Class {cls.__name__} has additional arguments :\n{only_in_cls}")
    if only_in_datacls:
        logging.error(f"Dataclass {datacls.__name__} has additional arguments :\n{only_in_datacls}")

    return False, only_in_cls, only_in_datacls
