from collections import deque
import contextlib
from functools import partial
import tree
from typing import Any, Dict, List, Optional, Tuple, Union

from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.filter_manager import FilterManager
from ray.rllib.utils.framework import (
    try_import_jax,
    try_import_tf,
    try_import_tfp,
    try_import_torch,
)
from ray.rllib.utils.numpy import (
    sigmoid,
    softmax,
    relu,
    one_hot,
    fc,
    lstm,
    SMALL_NUMBER,
    LARGE_INTEGER,
    MIN_LOG_NN_OUTPUT,
    MAX_LOG_NN_OUTPUT,
)
from ray.rllib.utils.schedules import (
    LinearSchedule,
    PiecewiseSchedule,
    PolynomialSchedule,
    ExponentialSchedule,
    ConstantSchedule,
)
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
)
from ray.tune.utils import merge_dicts, deep_update


@DeveloperAPI
def add_mixins(base, mixins, reversed=False):
    """Builds a class hierarchy that combines `base` with the given mixins.

    The mixins are processed from the last one to the first one. For each
    mixin, a new (empty) class is created that inherits from both the mixin
    and the class built so far.

    Args:
        base: The class to extend.
        mixins: An iterable of mixin classes, or None/empty for no mixins.
        reversed: If False (default), each mixin is placed BEFORE the class
            built so far in the bases list, so mixins take precedence over
            `base` and earlier mixins take precedence over later ones.
            If True, the class built so far is placed BEFORE the mixin, so
            `base` takes precedence over all mixins.

    Returns:
        The resulting class, or `base` itself if there are no mixins.
    """
    # NOTE: The `reversed` argument shadows the builtin of the same name, so
    # slicing is used to walk a private copy of the mixins back to front.
    for mixin in list(mixins or [])[::-1]:
        bases = (base, mixin) if reversed else (mixin, base)

        class new_base(*bases):
            pass

        base = new_base

    return base


@DeveloperAPI
def force_list(
    elements: Optional[Any] = None, to_tuple: bool = False
) -> Union[List, Tuple]:
    """
    Wraps or converts `elements` so that a list (or tuple) is always returned.

    Args:
        elements: A single item, a list/set/tuple/deque of items, or None.
            Only objects whose type is exactly one of these four container
            types are unpacked; anything else (including strings, dicts and
            subclasses of these containers) is treated as a single item.
        to_tuple: If exactly `True`, a tuple is returned instead of a list.

    Returns:
        An empty list/tuple if `elements` is None, the items of `elements`
        as a list/tuple if it is one of the recognized containers, or a
        list/tuple of size 1 holding `elements` otherwise.
    """
    container = tuple if to_tuple is True else list

    if elements is None:
        return container()

    # Exact type match on purpose: subclasses count as single items.
    if type(elements) in (list, set, tuple, deque):
        return container(elements)

    return container((elements,))


@DeveloperAPI
def flatten_dict(nested: Dict[str, Any], sep="/", env_steps=0) -> Dict[str, Any]:
    """
    Converts a nested dict into a single-level dict with joined path keys.

    Note, this is used for better serialization of nested dictionaries
    in `OfflinePreLearner.__call__` when called inside
    `ray.data.Dataset.map_batches`.

    Note, this is used to return a `Dict[str, numpy.ndarray]` from the
    `__call__` method which is expected by Ray Data.

    Args:
        nested: A nested dictionary.
        sep: Separator to use when joining keys.
        env_steps: Not used by this function.

    Returns:
        A flat dictionary where each key is a path of keys in the nested dict,
        with the path elements converted to strings and joined by `sep`.
    """
    # `tree.flatten_with_path` provides `(path, leaf)` pairs; every path is
    # collapsed into one string key.
    return {
        sep.join(str(part) for part in path): leaf
        for path, leaf in tree.flatten_with_path(nested)
    }


@DeveloperAPI
def unflatten_dict(flat: Dict[str, Any], sep="/") -> Dict[str, Any]:
    """
    Rebuilds a nested dict from a flat dict whose keys are joined paths.

    Note, this is used for better deserialization of nested dictionaries
    in `Learner.update` calls in which a `ray.data.DataIterator` is used.

    Args:
        flat: A flat dictionary with keys that are paths joined by `sep`.
        sep: The separator used in the flat dictionary keys.

    Returns:
        A nested dictionary.
    """
    result = {}
    for joined_key, leaf in flat.items():
        # All path elements but the last name intermediate dicts; the last
        # one is the key under which the value is stored.
        *parents, last = joined_key.split(sep)
        node = result
        for part in parents:
            # Descend into the sub-dict, creating it on first use.
            node = node.setdefault(part, {})
        node[last] = leaf

    return result


@DeveloperAPI
class NullContextManager(contextlib.AbstractContextManager):
    """Context manager that does nothing on construction, entry, or exit."""

    def __init__(self):
        """Takes no arguments and stores no state."""
        return None

    def __enter__(self):
        """Returns None, so `with ... as x` binds `x` to None."""
        return None

    def __exit__(self, *exc_info):
        """Returns None, so an exception raised in the body propagates."""
        return None


# Variant of `force_list` with `to_tuple=True` pre-bound, i.e. the result is
# returned as a tuple instead of a list.
force_tuple = partial(force_list, to_tuple=True)

# Names exported via `from <this module> import *`: helpers defined in this
# module plus names re-exported from the imports at the top of the file.
# NOTE(review): `flatten_dict`, `unflatten_dict` and `NullContextManager` are
# defined in this module but not listed here - confirm this is intentional.
__all__ = [
    "add_mixins",
    "check",
    "check_compute_single_action",
    "check_train_results",
    "deep_update",
    "deprecation_warning",
    "fc",
    "force_list",
    "force_tuple",
    "lstm",
    "merge_dicts",
    "one_hot",
    "override",
    "relu",
    "sigmoid",
    "softmax",
    "try_import_jax",
    "try_import_tf",
    "try_import_tfp",
    "try_import_torch",
    "ConstantSchedule",
    "DeveloperAPI",
    "ExponentialSchedule",
    "Filter",
    "FilterManager",
    "LARGE_INTEGER",
    "LinearSchedule",
    "MAX_LOG_NN_OUTPUT",
    "MIN_LOG_NN_OUTPUT",
    "PiecewiseSchedule",
    "PolynomialSchedule",
    "PublicAPI",
    "SMALL_NUMBER",
]
