from typing import Dict, Optional, Union

import numpy as np
import tensorflow as tf

from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed


def convert_ndarray_to_tf_tensor(
    ndarray: np.ndarray,
    dtype: Optional[tf.dtypes.DType] = None,
    type_spec: Optional[tf.TypeSpec] = None,
) -> tf.Tensor:
    """Convert a NumPy ndarray to a TensorFlow Tensor.

    Args:
        ndarray: The NumPy ndarray to convert.
        dtype: TensorFlow dtype for the resulting tensor. When omitted, the dtype
            of ``type_spec`` is used if one is given; otherwise TensorFlow infers
            the dtype from the array data.
        type_spec: Optional spec describing the desired tensor. This function
            only reads its dtype (as a fallback for ``dtype``) and whether it is
            a ``tf.RaggedTensorSpec``; the spec's shape is not enforced here.

    Returns: A ``tf.RaggedTensor`` when ``type_spec`` is a ``tf.RaggedTensorSpec``,
        otherwise a dense ``tf.Tensor``.
    """
    # An explicitly passed dtype wins over the one carried by the type spec.
    target_dtype = dtype
    if target_dtype is None and type_spec is not None:
        target_dtype = type_spec.dtype

    unwrapped = _unwrap_ndarray_object_type_if_needed(ndarray)

    # Ragged specs need a RaggedTensor; anything else becomes a dense tensor.
    if isinstance(type_spec, tf.RaggedTensorSpec):
        return tf.ragged.constant(unwrapped, dtype=target_dtype)
    return tf.convert_to_tensor(unwrapped, dtype=target_dtype)


def convert_ndarray_batch_to_tf_tensor_batch(
    ndarrays: Union[np.ndarray, Dict[str, np.ndarray]],
    dtypes: Optional[Union[tf.dtypes.DType, Dict[str, tf.dtypes.DType]]] = None,
) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
    """Convert a NumPy ndarray batch to a TensorFlow Tensor batch.

    Args:
        ndarrays: A (dict of) NumPy ndarray(s) that we wish to convert to
            TensorFlow Tensor(s).
        dtypes: A (dict of) TensorFlow dtype(s) for the created tensor(s); if
            None, the dtype will be inferred from the NumPy ndarray data.
            - If ``ndarrays`` is a single ndarray and ``dtypes`` is a dict, the
              dict must contain exactly one entry, whose value is used.
            - If ``ndarrays`` is a dict and ``dtypes`` is a dict, each column
              uses the dtype stored under the same column name.
            - If ``dtypes`` is not a dict, that single dtype (or None) is used
              for every column.

    Returns: A (dict of) TensorFlow Tensor(s), mirroring the structure of
        ``ndarrays``.

    Raises:
        ValueError: If ``ndarrays`` is a single ndarray and ``dtypes`` is a dict
            that does not contain exactly one entry.
        KeyError: If ``ndarrays`` and ``dtypes`` are both dicts and a column
            name present in ``ndarrays`` is missing from ``dtypes``.
    """
    dtypes_is_dict = isinstance(dtypes, dict)

    if isinstance(ndarrays, np.ndarray):
        # Single-tensor case.
        if dtypes_is_dict:
            if len(dtypes) != 1:
                raise ValueError(
                    "When constructing a single-tensor batch, only a single dtype "
                    f"should be given, instead got: {dtypes}"
                )
            # The key is irrelevant here; only the sole dtype value is used.
            dtypes = next(iter(dtypes.values()))
        return convert_ndarray_to_tf_tensor(ndarrays, dtypes)

    # Multi-tensor case: convert each column independently.
    return {
        col_name: convert_ndarray_to_tf_tensor(
            col_ndarray,
            dtype=dtypes[col_name] if dtypes_is_dict else dtypes,
        )
        for col_name, col_ndarray in ndarrays.items()
    }
